Spaces:
Sleeping
Sleeping
Update auditron env server and training script
Browse files
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
server.py
CHANGED
|
@@ -304,6 +304,12 @@ class AuditronEnv(Environment[AuditronAction, AuditronObservation, AuditronState
|
|
| 304 |
'Example: {"bid_price": 85, "actual_strength": 75}',
|
| 305 |
)
|
| 306 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
s.supplier_bids[agent_id] = {
|
| 308 |
"bid_price": bid_price,
|
| 309 |
"actual_strength": actual_strength,
|
|
|
|
| 304 |
'Example: {"bid_price": 85, "actual_strength": 75}',
|
| 305 |
)
|
| 306 |
|
| 307 |
+
# Clamp bid to at least production cost — penalize if model bid below it
|
| 308 |
+
production_cost = s.required_strength * s.supplier_costs.get(agent_id, 0)
|
| 309 |
+
if bid_price < production_cost:
|
| 310 |
+
s.supplier_rewards[agent_id] += PENALTY_INVALID_FORMAT / 2 # -2.5
|
| 311 |
+
bid_price = production_cost # clamp so game math stays valid
|
| 312 |
+
|
| 313 |
s.supplier_bids[agent_id] = {
|
| 314 |
"bid_price": bid_price,
|
| 315 |
"actual_strength": actual_strength,
|
train.py
CHANGED
|
@@ -25,7 +25,7 @@ from datetime import datetime
|
|
| 25 |
from trl import GRPOConfig, GRPOTrainer
|
| 26 |
|
| 27 |
# ── Config (all tunables in one place) ────────────────────────────────────────
|
| 28 |
-
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-
|
| 29 |
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096"))
|
| 30 |
LORA_RANK = int(os.environ.get("LORA_RANK", "16"))
|
| 31 |
NUM_TRAINING_STEPS = int(os.environ.get("NUM_TRAINING_STEPS", "500"))
|
|
@@ -702,148 +702,154 @@ def main():
|
|
| 702 |
output_dir=OUTPUT_DIR,
|
| 703 |
)
|
| 704 |
|
| 705 |
-
# Periodic checkpoint callback — pauses every CHECKPOINT_EVERY steps to run
|
| 706 |
-
# a 1-episode eval and log per-personality profits for line charts in reports.
|
| 707 |
-
from transformers import TrainerCallback
|
| 708 |
-
|
| 709 |
-
class CheckpointEvalCallback(TrainerCallback):
|
| 710 |
-
def on_step_end(self, args, state, control, **kwargs):
|
| 711 |
-
if state.global_step % CHECKPOINT_EVERY == 0 and state.global_step > 0:
|
| 712 |
-
print(f"\n[Checkpoint eval at step {state.global_step}/{NUM_TRAINING_STEPS}]")
|
| 713 |
-
evaluate_model(model, tokenizer, num_episodes=1, eval_step=state.global_step, max_rounds=CHECKPOINT_ROUNDS)
|
| 714 |
-
# Switch back to training mode after eval
|
| 715 |
-
from unsloth import FastLanguageModel
|
| 716 |
-
FastLanguageModel.for_training(model)
|
| 717 |
-
tokenizer.padding_side = "left"
|
| 718 |
-
|
| 719 |
trainer = GRPOTrainer(
|
| 720 |
model=model,
|
| 721 |
processing_class=tokenizer,
|
| 722 |
reward_funcs=[format_reward, reasoning_reward, economic_reward],
|
| 723 |
args=training_args,
|
| 724 |
train_dataset=dataset,
|
| 725 |
-
callbacks=[CheckpointEvalCallback()],
|
| 726 |
)
|
| 727 |
|
| 728 |
print(f"\nStarting GRPO training ({NUM_TRAINING_STEPS} steps)...")
|
| 729 |
print(f"Logs: reasoning={REASONING_LOG} episodes={EPISODE_LOG} eval={EVAL_LOG}")
|
| 730 |
-
print(f"Checkpoint evals every {CHECKPOINT_EVERY} steps.")
|
| 731 |
trainer.train()
|
| 732 |
print("Training complete!")
|
| 733 |
|
| 734 |
-
# 4.
|
| 735 |
-
model.save_pretrained_merged(OUTPUT_DIR, tokenizer, save_method="merged_16bit")
|
| 736 |
-
print(f"Model saved to {OUTPUT_DIR}/")
|
| 737 |
-
|
| 738 |
-
# 5. Final eval — 1 full 50-round episode, rich per-round logging for charts
|
| 739 |
print("\n[Final eval β full 50-round episode]")
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
{"bid_price": round(req * cost * 1.1, 1), "actual_strength": req})))
|
| 763 |
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
buy_action = generate_action(model, tokenizer, build_prompt(buy_obs, "buyer"))
|
| 777 |
-
result = final_env.step(AuditronAction(agent_id="buyer", content=buy_action))
|
| 778 |
-
try:
|
| 779 |
-
buy_parsed = json.loads(buy_action)
|
| 780 |
-
if not isinstance(buy_parsed, dict): buy_parsed = {}
|
| 781 |
-
except Exception:
|
| 782 |
-
buy_parsed = {}
|
| 783 |
-
|
| 784 |
-
resolution = result.observation.get("resolution", {})
|
| 785 |
-
winner = buy_parsed.get("pick")
|
| 786 |
-
failed = resolution.get("failed", False)
|
| 787 |
-
penalty = resolution.get("penalty", 0.0)
|
| 788 |
-
winner_bid = final_env.state.supplier_bids.get(winner, {}).get("bid_price", 0) if winner else 0
|
| 789 |
-
cumulative_spend += (winner_bid or 0) + (penalty or 0)
|
| 790 |
-
if failed:
|
| 791 |
-
cumulative_failures += 1
|
| 792 |
-
|
| 793 |
-
# Per-supplier data for this round
|
| 794 |
-
per_supplier = {}
|
| 795 |
-
for sid in SUPPLIER_IDS:
|
| 796 |
-
bid_info = final_env.state.supplier_bids.get(sid, {})
|
| 797 |
-
bid_price = bid_info.get("bid_price", 0) or 0
|
| 798 |
-
actual_str = bid_info.get("actual_strength", 0)
|
| 799 |
-
req_str = final_env.state.required_strength
|
| 800 |
-
cost = sup_obs_list[SUPPLIER_IDS.index(sid)].get("your_cost_per_point", 0)
|
| 801 |
-
production_cost = req_str * cost
|
| 802 |
-
round_profit = (bid_price - production_cost) if sid == winner and not failed else 0.0
|
| 803 |
-
cumulative_profits[sid] += round_profit
|
| 804 |
-
per_supplier[sid] = {
|
| 805 |
-
"personality": personalities[sid],
|
| 806 |
-
"bid_price": bid_price,
|
| 807 |
-
"actual_strength": actual_str,
|
| 808 |
-
"cheating": actual_str < req_str if actual_str else False,
|
| 809 |
-
"won": sid == winner,
|
| 810 |
-
"round_profit": round_profit,
|
| 811 |
-
"cumulative_profit": cumulative_profits[sid],
|
| 812 |
-
}
|
| 813 |
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
"
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
"
|
| 829 |
-
"
|
| 830 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 831 |
|
| 832 |
-
|
| 833 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 834 |
_log_episode({
|
| 835 |
"type": "final_summary",
|
| 836 |
"personalities": personalities,
|
| 837 |
-
"total_spend":
|
| 838 |
-
"total_failures":
|
| 839 |
-
"auditor_tpr":
|
| 840 |
-
"auditor_fpr":
|
| 841 |
-
"supplier_profits":
|
| 842 |
-
"supplier_ranking":
|
| 843 |
-
"final_rewards":
|
| 844 |
})
|
| 845 |
-
print(f"Final eval done. Spend={cumulative_spend:.1f} Failures={cumulative_failures}")
|
| 846 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
|
| 848 |
|
| 849 |
if __name__ == "__main__":
|
|
|
|
| 25 |
from trl import GRPOConfig, GRPOTrainer
|
| 26 |
|
| 27 |
# ── Config (all tunables in one place) ────────────────────────────────────────
|
| 28 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-7B-Instruct")
|
| 29 |
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096"))
|
| 30 |
LORA_RANK = int(os.environ.get("LORA_RANK", "16"))
|
| 31 |
NUM_TRAINING_STEPS = int(os.environ.get("NUM_TRAINING_STEPS", "500"))
|
|
|
|
| 702 |
output_dir=OUTPUT_DIR,
|
| 703 |
)
|
| 704 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
trainer = GRPOTrainer(
|
| 706 |
model=model,
|
| 707 |
processing_class=tokenizer,
|
| 708 |
reward_funcs=[format_reward, reasoning_reward, economic_reward],
|
| 709 |
args=training_args,
|
| 710 |
train_dataset=dataset,
|
|
|
|
| 711 |
)
|
| 712 |
|
| 713 |
print(f"\nStarting GRPO training ({NUM_TRAINING_STEPS} steps)...")
|
| 714 |
print(f"Logs: reasoning={REASONING_LOG} episodes={EPISODE_LOG} eval={EVAL_LOG}")
|
|
|
|
| 715 |
trainer.train()
|
| 716 |
print("Training complete!")
|
| 717 |
|
| 718 |
+
# 4. Final eval — run BEFORE saving (save_pretrained_merged modifies model internals)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
print("\n[Final eval β full 50-round episode]")
|
| 720 |
+
try:
|
| 721 |
+
final_env = AuditronEnv()
|
| 722 |
+
final_env.reset(seed=99999)
|
| 723 |
+
s = final_env.state
|
| 724 |
+
personalities = {sid: s.supplier_personalities[sid]["name"] for sid in SUPPLIER_IDS}
|
| 725 |
+
print(f"Personalities: {personalities}")
|
| 726 |
+
cumulative_spend = 0.0
|
| 727 |
+
cumulative_failures = 0
|
| 728 |
+
cumulative_profits = {sid: 0.0 for sid in SUPPLIER_IDS}
|
| 729 |
+
|
| 730 |
+
for rnd in range(TOTAL_PARTS):
|
| 731 |
+
# Suppliers
|
| 732 |
+
sup_obs_list = [final_env.get_supplier_obs(sid) for sid in SUPPLIER_IDS]
|
| 733 |
+
sup_prompts = [build_prompt(obs, "supplier") for obs in sup_obs_list]
|
| 734 |
+
sup_actions = generate_actions_batch(model, tokenizer, sup_prompts, max_new_tokens=64)
|
| 735 |
+
for sid, obs, action_str in zip(SUPPLIER_IDS, sup_obs_list, sup_actions):
|
| 736 |
+
result = final_env.step(AuditronAction(agent_id=sid, content=action_str))
|
| 737 |
+
if result.phase == "error":
|
| 738 |
+
req = obs["required_strength"]
|
| 739 |
+
cost = obs["your_cost_per_point"]
|
| 740 |
+
final_env.step(AuditronAction(agent_id=sid, content=json.dumps(
|
| 741 |
+
{"bid_price": round(req * cost * 1.1, 1), "actual_strength": req})))
|
|
|
|
| 742 |
|
| 743 |
+
# Auditor
|
| 744 |
+
aud_obs = final_env.get_auditor_obs()
|
| 745 |
+
aud_action = generate_action(model, tokenizer, build_prompt(aud_obs, "auditor"))
|
| 746 |
+
final_env.step(AuditronAction(agent_id="auditor", content=aud_action))
|
| 747 |
+
try:
|
| 748 |
+
aud_parsed = json.loads(aud_action)
|
| 749 |
+
if not isinstance(aud_parsed, dict): aud_parsed = {}
|
| 750 |
+
except Exception:
|
| 751 |
+
aud_parsed = {}
|
| 752 |
|
| 753 |
+
# Capture bids before buyer step resets supplier_bids
|
| 754 |
+
captured_bids = {sid: dict(final_env.state.supplier_bids.get(sid, {})) for sid in SUPPLIER_IDS}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
|
| 756 |
+
# Buyer
|
| 757 |
+
buy_obs = final_env.get_buyer_obs()
|
| 758 |
+
buy_action = generate_action(model, tokenizer, build_prompt(buy_obs, "buyer"))
|
| 759 |
+
result = final_env.step(AuditronAction(agent_id="buyer", content=buy_action))
|
| 760 |
+
if result.phase == "error":
|
| 761 |
+
fallback_pick = aud_parsed.get("pick") or SUPPLIER_IDS[0]
|
| 762 |
+
result = final_env.step(AuditronAction(agent_id="buyer", content=json.dumps(
|
| 763 |
+
{"pick": fallback_pick, "reason": "fallback"})))
|
| 764 |
+
try:
|
| 765 |
+
buy_parsed = json.loads(buy_action)
|
| 766 |
+
if not isinstance(buy_parsed, dict): buy_parsed = {}
|
| 767 |
+
except Exception:
|
| 768 |
+
buy_parsed = {}
|
| 769 |
+
|
| 770 |
+
resolution = result.observation.get("resolution", {})
|
| 771 |
+
winner = buy_parsed.get("pick")
|
| 772 |
+
failed = resolution.get("failed", False)
|
| 773 |
+
penalty = resolution.get("penalty", 0.0)
|
| 774 |
+
winner_bid = captured_bids.get(winner, {}).get("bid_price", 0) if winner else 0
|
| 775 |
+
cumulative_spend += (winner_bid or 0) + (penalty or 0)
|
| 776 |
+
if failed:
|
| 777 |
+
cumulative_failures += 1
|
| 778 |
+
|
| 779 |
+
# Per-supplier data for this round
|
| 780 |
+
per_supplier = {}
|
| 781 |
+
for sid in SUPPLIER_IDS:
|
| 782 |
+
bid_info = captured_bids.get(sid, {})
|
| 783 |
+
bid_price = bid_info.get("bid_price", 0) or 0
|
| 784 |
+
actual_str = bid_info.get("actual_strength", 0)
|
| 785 |
+
req_str = final_env.state.required_strength
|
| 786 |
+
cost = sup_obs_list[SUPPLIER_IDS.index(sid)].get("your_cost_per_point", 0)
|
| 787 |
+
production_cost = req_str * cost
|
| 788 |
+
round_profit = (bid_price - production_cost) if sid == winner and not failed else 0.0
|
| 789 |
+
cumulative_profits[sid] += round_profit
|
| 790 |
+
per_supplier[sid] = {
|
| 791 |
+
"personality": personalities[sid],
|
| 792 |
+
"bid_price": bid_price,
|
| 793 |
+
"actual_strength": actual_str,
|
| 794 |
+
"cheating": actual_str < req_str if actual_str else False,
|
| 795 |
+
"won": sid == winner,
|
| 796 |
+
"round_profit": round_profit,
|
| 797 |
+
"cumulative_profit": cumulative_profits[sid],
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
_log_episode({
|
| 801 |
+
"type": "final_round",
|
| 802 |
+
"round": rnd + 1,
|
| 803 |
+
"required_strength": final_env.state.required_strength,
|
| 804 |
+
"personalities": personalities,
|
| 805 |
+
"auditor_pick": aud_parsed.get("pick"),
|
| 806 |
+
"auditor_flags": aud_parsed.get("flags", []),
|
| 807 |
+
"auditor_reason": aud_parsed.get("reason", ""),
|
| 808 |
+
"buyer_pick": winner,
|
| 809 |
+
"buyer_followed_auditor": (winner == aud_parsed.get("pick")) if winner else None,
|
| 810 |
+
"part_failed": failed,
|
| 811 |
+
"failure_penalty": penalty or 0.0,
|
| 812 |
+
"round_spend": (winner_bid or 0) + (penalty or 0),
|
| 813 |
+
"cumulative_spend": cumulative_spend,
|
| 814 |
+
"cumulative_failures": cumulative_failures,
|
| 815 |
+
"per_supplier": per_supplier,
|
| 816 |
+
})
|
| 817 |
|
| 818 |
+
if result.done:
|
| 819 |
+
summary = result.observation.get("episode_summary", {})
|
| 820 |
+
_log_episode({
|
| 821 |
+
"type": "final_summary",
|
| 822 |
+
"personalities": personalities,
|
| 823 |
+
"total_spend": summary.get("buyer_total_spend", cumulative_spend),
|
| 824 |
+
"total_failures": summary.get("num_failures", cumulative_failures),
|
| 825 |
+
"auditor_tpr": summary.get("auditor_tpr"),
|
| 826 |
+
"auditor_fpr": summary.get("auditor_fpr"),
|
| 827 |
+
"supplier_profits": summary.get("supplier_profits", {}),
|
| 828 |
+
"supplier_ranking": summary.get("supplier_ranking", []),
|
| 829 |
+
"final_rewards": summary.get("final_rewards", {}),
|
| 830 |
+
})
|
| 831 |
+
print(f"Final eval done. Spend={cumulative_spend:.1f} Failures={cumulative_failures}")
|
| 832 |
+
break
|
| 833 |
+
else:
|
| 834 |
+
# Loop ended without done — log summary from accumulated data
|
| 835 |
_log_episode({
|
| 836 |
"type": "final_summary",
|
| 837 |
"personalities": personalities,
|
| 838 |
+
"total_spend": cumulative_spend,
|
| 839 |
+
"total_failures": cumulative_failures,
|
| 840 |
+
"auditor_tpr": None,
|
| 841 |
+
"auditor_fpr": None,
|
| 842 |
+
"supplier_profits": {sid: cumulative_profits[sid] for sid in SUPPLIER_IDS},
|
| 843 |
+
"supplier_ranking": sorted(SUPPLIER_IDS, key=lambda s: cumulative_profits[s], reverse=True),
|
| 844 |
+
"final_rewards": {},
|
| 845 |
})
|
| 846 |
+
print(f"Final eval done (fallback summary). Spend={cumulative_spend:.1f} Failures={cumulative_failures}")
|
| 847 |
+
except Exception as e:
|
| 848 |
+
print(f"Final eval failed: {e}")
|
| 849 |
+
|
| 850 |
+
# 5. Save
|
| 851 |
+
model.save_pretrained_merged(OUTPUT_DIR, tokenizer, save_method="merged_16bit")
|
| 852 |
+
print(f"Model saved to {OUTPUT_DIR}/")
|
| 853 |
|
| 854 |
|
| 855 |
if __name__ == "__main__":
|