shapiron Claude Sonnet 4.6 committed on
Commit
0feab3a
·
1 Parent(s): 3b7ac2d

Update auditron env server and training script

Browse files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. server.py +6 -0
  2. train.py +127 -121
server.py CHANGED
@@ -304,6 +304,12 @@ class AuditronEnv(Environment[AuditronAction, AuditronObservation, AuditronState
304
  'Example: {"bid_price": 85, "actual_strength": 75}',
305
  )
306
 
 
 
 
 
 
 
307
  s.supplier_bids[agent_id] = {
308
  "bid_price": bid_price,
309
  "actual_strength": actual_strength,
 
304
  'Example: {"bid_price": 85, "actual_strength": 75}',
305
  )
306
 
307
+ # Clamp bid to at least production cost — penalize if model bid below it
308
+ production_cost = s.required_strength * s.supplier_costs.get(agent_id, 0)
309
+ if bid_price < production_cost:
310
+ s.supplier_rewards[agent_id] += PENALTY_INVALID_FORMAT / 2 # -2.5
311
+ bid_price = production_cost # clamp so game math stays valid
312
+
313
  s.supplier_bids[agent_id] = {
314
  "bid_price": bid_price,
315
  "actual_strength": actual_strength,
train.py CHANGED
@@ -25,7 +25,7 @@ from datetime import datetime
25
  from trl import GRPOConfig, GRPOTrainer
26
 
27
  # ── Config (all tunables in one place) ────────────────────────────────────────
28
- MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-1.5B-Instruct")
29
  MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096"))
30
  LORA_RANK = int(os.environ.get("LORA_RANK", "16"))
31
  NUM_TRAINING_STEPS = int(os.environ.get("NUM_TRAINING_STEPS", "500"))
@@ -702,148 +702,154 @@ def main():
702
  output_dir=OUTPUT_DIR,
703
  )
704
 
705
- # Periodic checkpoint callback — pauses every CHECKPOINT_EVERY steps to run
706
- # a 1-episode eval and log per-personality profits for line charts in reports.
707
- from transformers import TrainerCallback
708
-
709
- class CheckpointEvalCallback(TrainerCallback):
710
- def on_step_end(self, args, state, control, **kwargs):
711
- if state.global_step % CHECKPOINT_EVERY == 0 and state.global_step > 0:
712
- print(f"\n[Checkpoint eval at step {state.global_step}/{NUM_TRAINING_STEPS}]")
713
- evaluate_model(model, tokenizer, num_episodes=1, eval_step=state.global_step, max_rounds=CHECKPOINT_ROUNDS)
714
- # Switch back to training mode after eval
715
- from unsloth import FastLanguageModel
716
- FastLanguageModel.for_training(model)
717
- tokenizer.padding_side = "left"
718
-
719
  trainer = GRPOTrainer(
720
  model=model,
721
  processing_class=tokenizer,
722
  reward_funcs=[format_reward, reasoning_reward, economic_reward],
723
  args=training_args,
724
  train_dataset=dataset,
725
- callbacks=[CheckpointEvalCallback()],
726
  )
727
 
728
  print(f"\nStarting GRPO training ({NUM_TRAINING_STEPS} steps)...")
729
  print(f"Logs: reasoning={REASONING_LOG} episodes={EPISODE_LOG} eval={EVAL_LOG}")
730
- print(f"Checkpoint evals every {CHECKPOINT_EVERY} steps.")
731
  trainer.train()
732
  print("Training complete!")
733
 
734
- # 4. Save
735
- model.save_pretrained_merged(OUTPUT_DIR, tokenizer, save_method="merged_16bit")
736
- print(f"Model saved to {OUTPUT_DIR}/")
737
-
738
- # 5. Final eval — 1 full 50-round episode, rich per-round logging for charts
739
  print("\n[Final eval — full 50-round episode]")
740
- from unsloth import FastLanguageModel
741
- FastLanguageModel.for_inference(model)
742
- final_env = AuditronEnv()
743
- final_env.reset(seed=99999)
744
- s = final_env.state
745
- personalities = {sid: s.supplier_personalities[sid]["name"] for sid in SUPPLIER_IDS}
746
- print(f"Personalities: {personalities}")
747
- cumulative_spend = 0.0
748
- cumulative_failures = 0
749
- cumulative_profits = {sid: 0.0 for sid in SUPPLIER_IDS}
750
-
751
- for rnd in range(TOTAL_PARTS):
752
- # Suppliers
753
- sup_obs_list = [final_env.get_supplier_obs(sid) for sid in SUPPLIER_IDS]
754
- sup_prompts = [build_prompt(obs, "supplier") for obs in sup_obs_list]
755
- sup_actions = generate_actions_batch(model, tokenizer, sup_prompts, max_new_tokens=64)
756
- for sid, obs, action_str in zip(SUPPLIER_IDS, sup_obs_list, sup_actions):
757
- result = final_env.step(AuditronAction(agent_id=sid, content=action_str))
758
- if result.phase == "error":
759
- req = obs["required_strength"]
760
- cost = obs["your_cost_per_point"]
761
- final_env.step(AuditronAction(agent_id=sid, content=json.dumps(
762
- {"bid_price": round(req * cost * 1.1, 1), "actual_strength": req})))
763
 
764
- # Auditor
765
- aud_obs = final_env.get_auditor_obs()
766
- aud_action = generate_action(model, tokenizer, build_prompt(aud_obs, "auditor"))
767
- final_env.step(AuditronAction(agent_id="auditor", content=aud_action))
768
- try:
769
- aud_parsed = json.loads(aud_action)
770
- if not isinstance(aud_parsed, dict): aud_parsed = {}
771
- except Exception:
772
- aud_parsed = {}
773
 
774
- # Buyer
775
- buy_obs = final_env.get_buyer_obs()
776
- buy_action = generate_action(model, tokenizer, build_prompt(buy_obs, "buyer"))
777
- result = final_env.step(AuditronAction(agent_id="buyer", content=buy_action))
778
- try:
779
- buy_parsed = json.loads(buy_action)
780
- if not isinstance(buy_parsed, dict): buy_parsed = {}
781
- except Exception:
782
- buy_parsed = {}
783
-
784
- resolution = result.observation.get("resolution", {})
785
- winner = buy_parsed.get("pick")
786
- failed = resolution.get("failed", False)
787
- penalty = resolution.get("penalty", 0.0)
788
- winner_bid = final_env.state.supplier_bids.get(winner, {}).get("bid_price", 0) if winner else 0
789
- cumulative_spend += (winner_bid or 0) + (penalty or 0)
790
- if failed:
791
- cumulative_failures += 1
792
-
793
- # Per-supplier data for this round
794
- per_supplier = {}
795
- for sid in SUPPLIER_IDS:
796
- bid_info = final_env.state.supplier_bids.get(sid, {})
797
- bid_price = bid_info.get("bid_price", 0) or 0
798
- actual_str = bid_info.get("actual_strength", 0)
799
- req_str = final_env.state.required_strength
800
- cost = sup_obs_list[SUPPLIER_IDS.index(sid)].get("your_cost_per_point", 0)
801
- production_cost = req_str * cost
802
- round_profit = (bid_price - production_cost) if sid == winner and not failed else 0.0
803
- cumulative_profits[sid] += round_profit
804
- per_supplier[sid] = {
805
- "personality": personalities[sid],
806
- "bid_price": bid_price,
807
- "actual_strength": actual_str,
808
- "cheating": actual_str < req_str if actual_str else False,
809
- "won": sid == winner,
810
- "round_profit": round_profit,
811
- "cumulative_profit": cumulative_profits[sid],
812
- }
813
 
814
- _log_episode({
815
- "type": "final_round",
816
- "round": rnd + 1,
817
- "required_strength": final_env.state.required_strength,
818
- "personalities": personalities,
819
- "auditor_pick": aud_parsed.get("pick"),
820
- "auditor_flags": aud_parsed.get("flags", []),
821
- "auditor_reason": aud_parsed.get("reason", ""),
822
- "buyer_pick": winner,
823
- "buyer_followed_auditor": (winner == aud_parsed.get("pick")) if winner else None,
824
- "part_failed": failed,
825
- "failure_penalty": penalty or 0.0,
826
- "round_spend": (winner_bid or 0) + (penalty or 0),
827
- "cumulative_spend": cumulative_spend,
828
- "cumulative_failures": cumulative_failures,
829
- "per_supplier": per_supplier,
830
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831
 
832
- if result.done:
833
- summary = result.observation.get("episode_summary", {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
834
  _log_episode({
835
  "type": "final_summary",
836
  "personalities": personalities,
837
- "total_spend": summary.get("buyer_total_spend", cumulative_spend),
838
- "total_failures": summary.get("num_failures", cumulative_failures),
839
- "auditor_tpr": summary.get("auditor_tpr"),
840
- "auditor_fpr": summary.get("auditor_fpr"),
841
- "supplier_profits": summary.get("supplier_profits", {}),
842
- "supplier_ranking": summary.get("supplier_ranking", []),
843
- "final_rewards": summary.get("final_rewards", {}),
844
  })
845
- print(f"Final eval done. Spend={cumulative_spend:.1f} Failures={cumulative_failures}")
846
- break
 
 
 
 
 
847
 
848
 
849
  if __name__ == "__main__":
 
25
  from trl import GRPOConfig, GRPOTrainer
26
 
27
  # ── Config (all tunables in one place) ────────────────────────────────────────
28
+ MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-7B-Instruct")
29
  MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096"))
30
  LORA_RANK = int(os.environ.get("LORA_RANK", "16"))
31
  NUM_TRAINING_STEPS = int(os.environ.get("NUM_TRAINING_STEPS", "500"))
 
702
  output_dir=OUTPUT_DIR,
703
  )
704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
  trainer = GRPOTrainer(
706
  model=model,
707
  processing_class=tokenizer,
708
  reward_funcs=[format_reward, reasoning_reward, economic_reward],
709
  args=training_args,
710
  train_dataset=dataset,
 
711
  )
712
 
713
  print(f"\nStarting GRPO training ({NUM_TRAINING_STEPS} steps)...")
714
  print(f"Logs: reasoning={REASONING_LOG} episodes={EPISODE_LOG} eval={EVAL_LOG}")
 
715
  trainer.train()
716
  print("Training complete!")
717
 
718
+ # 4. Final eval — run BEFORE saving (save_pretrained_merged modifies model internals)
 
 
 
 
719
  print("\n[Final eval — full 50-round episode]")
720
+ try:
721
+ final_env = AuditronEnv()
722
+ final_env.reset(seed=99999)
723
+ s = final_env.state
724
+ personalities = {sid: s.supplier_personalities[sid]["name"] for sid in SUPPLIER_IDS}
725
+ print(f"Personalities: {personalities}")
726
+ cumulative_spend = 0.0
727
+ cumulative_failures = 0
728
+ cumulative_profits = {sid: 0.0 for sid in SUPPLIER_IDS}
729
+
730
+ for rnd in range(TOTAL_PARTS):
731
+ # Suppliers
732
+ sup_obs_list = [final_env.get_supplier_obs(sid) for sid in SUPPLIER_IDS]
733
+ sup_prompts = [build_prompt(obs, "supplier") for obs in sup_obs_list]
734
+ sup_actions = generate_actions_batch(model, tokenizer, sup_prompts, max_new_tokens=64)
735
+ for sid, obs, action_str in zip(SUPPLIER_IDS, sup_obs_list, sup_actions):
736
+ result = final_env.step(AuditronAction(agent_id=sid, content=action_str))
737
+ if result.phase == "error":
738
+ req = obs["required_strength"]
739
+ cost = obs["your_cost_per_point"]
740
+ final_env.step(AuditronAction(agent_id=sid, content=json.dumps(
741
+ {"bid_price": round(req * cost * 1.1, 1), "actual_strength": req})))
 
742
 
743
+ # Auditor
744
+ aud_obs = final_env.get_auditor_obs()
745
+ aud_action = generate_action(model, tokenizer, build_prompt(aud_obs, "auditor"))
746
+ final_env.step(AuditronAction(agent_id="auditor", content=aud_action))
747
+ try:
748
+ aud_parsed = json.loads(aud_action)
749
+ if not isinstance(aud_parsed, dict): aud_parsed = {}
750
+ except Exception:
751
+ aud_parsed = {}
752
 
753
+ # Capture bids before buyer step resets supplier_bids
754
+ captured_bids = {sid: dict(final_env.state.supplier_bids.get(sid, {})) for sid in SUPPLIER_IDS}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755
 
756
+ # Buyer
757
+ buy_obs = final_env.get_buyer_obs()
758
+ buy_action = generate_action(model, tokenizer, build_prompt(buy_obs, "buyer"))
759
+ result = final_env.step(AuditronAction(agent_id="buyer", content=buy_action))
760
+ if result.phase == "error":
761
+ fallback_pick = aud_parsed.get("pick") or SUPPLIER_IDS[0]
762
+ result = final_env.step(AuditronAction(agent_id="buyer", content=json.dumps(
763
+ {"pick": fallback_pick, "reason": "fallback"})))
764
+ try:
765
+ buy_parsed = json.loads(buy_action)
766
+ if not isinstance(buy_parsed, dict): buy_parsed = {}
767
+ except Exception:
768
+ buy_parsed = {}
769
+
770
+ resolution = result.observation.get("resolution", {})
771
+ winner = buy_parsed.get("pick")
772
+ failed = resolution.get("failed", False)
773
+ penalty = resolution.get("penalty", 0.0)
774
+ winner_bid = captured_bids.get(winner, {}).get("bid_price", 0) if winner else 0
775
+ cumulative_spend += (winner_bid or 0) + (penalty or 0)
776
+ if failed:
777
+ cumulative_failures += 1
778
+
779
+ # Per-supplier data for this round
780
+ per_supplier = {}
781
+ for sid in SUPPLIER_IDS:
782
+ bid_info = captured_bids.get(sid, {})
783
+ bid_price = bid_info.get("bid_price", 0) or 0
784
+ actual_str = bid_info.get("actual_strength", 0)
785
+ req_str = final_env.state.required_strength
786
+ cost = sup_obs_list[SUPPLIER_IDS.index(sid)].get("your_cost_per_point", 0)
787
+ production_cost = req_str * cost
788
+ round_profit = (bid_price - production_cost) if sid == winner and not failed else 0.0
789
+ cumulative_profits[sid] += round_profit
790
+ per_supplier[sid] = {
791
+ "personality": personalities[sid],
792
+ "bid_price": bid_price,
793
+ "actual_strength": actual_str,
794
+ "cheating": actual_str < req_str if actual_str else False,
795
+ "won": sid == winner,
796
+ "round_profit": round_profit,
797
+ "cumulative_profit": cumulative_profits[sid],
798
+ }
799
+
800
+ _log_episode({
801
+ "type": "final_round",
802
+ "round": rnd + 1,
803
+ "required_strength": final_env.state.required_strength,
804
+ "personalities": personalities,
805
+ "auditor_pick": aud_parsed.get("pick"),
806
+ "auditor_flags": aud_parsed.get("flags", []),
807
+ "auditor_reason": aud_parsed.get("reason", ""),
808
+ "buyer_pick": winner,
809
+ "buyer_followed_auditor": (winner == aud_parsed.get("pick")) if winner else None,
810
+ "part_failed": failed,
811
+ "failure_penalty": penalty or 0.0,
812
+ "round_spend": (winner_bid or 0) + (penalty or 0),
813
+ "cumulative_spend": cumulative_spend,
814
+ "cumulative_failures": cumulative_failures,
815
+ "per_supplier": per_supplier,
816
+ })
817
 
818
+ if result.done:
819
+ summary = result.observation.get("episode_summary", {})
820
+ _log_episode({
821
+ "type": "final_summary",
822
+ "personalities": personalities,
823
+ "total_spend": summary.get("buyer_total_spend", cumulative_spend),
824
+ "total_failures": summary.get("num_failures", cumulative_failures),
825
+ "auditor_tpr": summary.get("auditor_tpr"),
826
+ "auditor_fpr": summary.get("auditor_fpr"),
827
+ "supplier_profits": summary.get("supplier_profits", {}),
828
+ "supplier_ranking": summary.get("supplier_ranking", []),
829
+ "final_rewards": summary.get("final_rewards", {}),
830
+ })
831
+ print(f"Final eval done. Spend={cumulative_spend:.1f} Failures={cumulative_failures}")
832
+ break
833
+ else:
834
+ # Loop ended without done — log summary from accumulated data
835
  _log_episode({
836
  "type": "final_summary",
837
  "personalities": personalities,
838
+ "total_spend": cumulative_spend,
839
+ "total_failures": cumulative_failures,
840
+ "auditor_tpr": None,
841
+ "auditor_fpr": None,
842
+ "supplier_profits": {sid: cumulative_profits[sid] for sid in SUPPLIER_IDS},
843
+ "supplier_ranking": sorted(SUPPLIER_IDS, key=lambda s: cumulative_profits[s], reverse=True),
844
+ "final_rewards": {},
845
  })
846
+ print(f"Final eval done (fallback summary). Spend={cumulative_spend:.1f} Failures={cumulative_failures}")
847
+ except Exception as e:
848
+ print(f"Final eval failed: {e}")
849
+
850
+ # 5. Save
851
+ model.save_pretrained_merged(OUTPUT_DIR, tokenizer, save_method="merged_16bit")
852
+ print(f"Model saved to {OUTPUT_DIR}/")
853
 
854
 
855
  if __name__ == "__main__":