"""Baseline CEO policies for SimMart — random, heuristic, oracle. Used for: • Smoke-testing the environment end-to-end before GRPO training • Establishing a reward floor (random), ceiling (oracle), and a sensible non-learned target (heuristic) that the trained Qwen CEO must beat • Generating the "before vs after training" contrast trace for the demo Usage: python inference.py --policy heuristic --seed 42 --episodes 3 --verbose python inference.py --policy all --episodes 10 --quiet python inference.py --trace hero_episode.json --policy oracle --seed 7 """ from __future__ import annotations import argparse import json import os import random import statistics import sys from dataclasses import dataclass from typing import Any, Dict, List, Optional # Make repo importable no matter where we run from HERE = os.path.dirname(os.path.abspath(__file__)) if HERE not in sys.path: sys.path.insert(0, HERE) from models import ( CrisisEvent, Proposal, ProposalDecision, SimMartAction, SimMartObservation, ) from server.environment import SimMartEnvironment # --------------------------------------------------------------------------- # Policy base # --------------------------------------------------------------------------- class CEOPolicy: """Interface: given an obs, return a SimMartAction for this week. Policies may inspect `env.state` (oracle) or not (random/heuristic). 
""" name: str = "base" def act( self, obs: SimMartObservation, env: Optional[SimMartEnvironment] = None, week: int = 0, ) -> SimMartAction: raise NotImplementedError # --------------------------------------------------------------------------- # Random CEO # --------------------------------------------------------------------------- class RandomCEO(CEOPolicy): name = "random" def __init__(self, seed: int = 0): self._rng = random.Random(seed) def act(self, obs, env=None, week=0): verdicts = ["approve", "reject", "flag_suspicious"] weights = [0.65, 0.25, 0.10] decisions = [ ProposalDecision( proposal_id=p.proposal_id, verdict=self._rng.choices(verdicts, weights=weights)[0], ) for p in obs.inbox ] return SimMartAction( decisions=decisions, budget_allocations={"supply_chain": 5e6, "store_ops": 5e5, "finance": 1e6, "growth": 1e6}, journal_entry=f"Random CEO, week {week}. No pattern.", ) # --------------------------------------------------------------------------- # All-Approve CEO — baseline that approves everything (the GRPO step-80 hack). # Used to measure the EV gap between blind-approve and rule-based policies. 
# ---------------------------------------------------------------------------

class AllApproveCEO(CEOPolicy):
    """Approve every proposal unconditionally (the GRPO step-80 reward hack)."""

    name = "all_approve"

    def act(self, obs, env=None, week=0):
        """Approve the whole inbox with a fixed max-spend budget split."""
        decisions = [
            ProposalDecision(proposal_id=p.proposal_id, verdict="approve")
            for p in obs.inbox
        ]
        return SimMartAction(
            decisions=decisions,
            budget_allocations={"supply_chain": 1e7, "store_ops": 2e6, "finance": 1e6, "growth": 2e6},
            journal_entry=f"All-approve CEO, week {week}.",
        )


# ---------------------------------------------------------------------------
# Heuristic CEO — hand-coded but sensible; reads only the public obs
# ---------------------------------------------------------------------------

class HeuristicCEO(CEOPolicy):
    """Rule-based CEO using only the public observation (no env.state peeking)."""

    name = "heuristic"

    def __init__(self, cash_floor_inr: float = 3e7):
        # Below this cash level the policy turns defensive (reject low-urgency,
        # throttle budgets).
        self._cash_floor = cash_floor_inr

    @staticmethod
    def _crisis_flags(active: List[CrisisEvent]) -> Dict[str, bool]:
        """Map each active crisis id to True for quick membership checks."""
        f: Dict[str, bool] = {}
        for c in active:
            f[c.crisis_id] = True
        return f

    def act(self, obs, env=None, week=0):
        """Decide every inbox proposal via `_decide`, then allocate budget and journal."""
        decisions: List[ProposalDecision] = []
        cash = obs.kpi_snapshot.cash_inr
        tight_cash = cash < self._cash_floor
        crisis_active = self._crisis_flags(obs.active_crises)
        for p in obs.inbox:
            verdict = self._decide(p, tight_cash, crisis_active)
            decisions.append(ProposalDecision(proposal_id=p.proposal_id, verdict=verdict))
        # Budget allocation: lean into supply chain when not cash-tight
        if tight_cash:
            budget = {"supply_chain": 5e6, "store_ops": 5e5, "finance": 5e5, "growth": 5e5}
        else:
            budget = {"supply_chain": 1e7, "store_ops": 2e6, "finance": 1e6, "growth": 2e6}
        journal = self._write_journal(obs, decisions, tight_cash, week)
        return SimMartAction(decisions=decisions, budget_allocations=budget, journal_entry=journal)

    # Per-action spend ceilings beyond which the heuristic rejects (4dept-reject
    # env: opex-bearing actions hit EBITDA directly, so high-spend low-ROI
    # actions are net-negative-EV and must be rejected).
    # NOTE: amounts use Indian digit grouping (15_00_000 = ₹15 lakh).
    SPEND_CEILINGS_INR: Dict[str, float] = {
        "campaign.launch": 15_00_000,   # spend > ₹15L → revenue uplift can't justify
        "capex.approve": 10_00_000,     # capex has no in-env revenue model → reject most
        "return.approve": 50_000,       # batch refunds > ₹50k usually wasteful
        "brand.ambassador": 5_00_000,   # NPS uplift +1.5 caps benefit
        "loyalty.update": 5_00_000,     # NPS uplift +0.8 caps benefit
        "hours.extend": 50_000,         # tiny revenue mult vs payroll
    }

    @classmethod
    def _decide(
        cls,
        p: Proposal,
        tight_cash: bool,
        crisis_active: Dict[str, bool],
    ) -> str:
        """Rule-based verdict.

        Reliable rogue tell from the obs alone:
          R2 ghost vendor : vendor_id starts with 'V-SUSPICIOUS-'
        Other rogue tells (R1 qty inflation, R3 cost inflation) need baselines
        not in the obs — Oracle gets those via env.state.

        Negative-EV reject rules (4dept-reject env):
          campaign/capex/refund/loyalty/ambassador/hours.extend get rejected
          when their spend exceeds a per-action ceiling, since the opex hit to
          EBITDA outweighs the modeled benefit.
        """
        # NOTE(review): `crisis_active` is currently unused by these rules;
        # kept in the signature for future crisis-aware verdicts.
        params = p.params or {}
        vendor_id = str(params.get("vendor_id", "") or "")
        if vendor_id.upper().startswith("V-SUSPICIOUS"):
            return "flag_suspicious"
        ceiling = cls.SPEND_CEILINGS_INR.get(p.action)
        if ceiling is not None:
            # `or`-chained lookups: first truthy spend field wins; a literal 0
            # in an early field falls through to the later ones.
            spend = float(
                params.get("spend_inr")
                or params.get("amount_inr")
                or params.get("refund_inr")
                or params.get("perk_cost_inr")
                or params.get("cost_inr")
                or (params.get("hours", 0) * 60 * 100 if p.action == "hours.extend" else 0)
                or 0.0
            )
            if spend > ceiling:
                return "reject"
        if p.urgency == "high":
            return "approve"
        if p.urgency == "low" and tight_cash:
            return "reject"
        return "approve"

    @staticmethod
    def _write_journal(obs, decisions, tight_cash, week):
        """Compose the weekly journal line from verdict counts and KPI snapshot."""
        n_flag = sum(1 for d in decisions if d.verdict == "flag_suspicious")
        n_reject = sum(1 for d in decisions if d.verdict == "reject")
        n_approve = sum(1 for d in decisions if d.verdict == "approve")
        kpi = obs.kpi_snapshot
        active = [f"{c.crisis_id}:{c.name}" for c in obs.active_crises]
        return (
            f"Week {week}: {n_approve} approved, {n_reject} rejected, {n_flag} flagged. "
            f"Cash ₹{kpi.cash_inr/1e7:+.2f} Cr (cash-tight={tight_cash}), "
            f"NPS {kpi.nps:.0f}, stockout {kpi.stockout_rate_pct:.1f}%, "
            f"SLA {kpi.delivery_sla_hit_rate_pct:.0f}%. "
            f"Active crises: {', '.join(active) or 'none'}. "
            f"Next week: watch cash and supply-chain replenishment."
        )


# ---------------------------------------------------------------------------
# Oracle CEO — peeks at env.state.rogue_incidents to flag rogues perfectly
# ---------------------------------------------------------------------------

class OracleCEO(HeuristicCEO):
    """Behaves like the heuristic but with two cheats:

    1. Flags every rogue-associated proposal this week (via env.state)
    2. Drives up approvals when NOT tight-cash (to avoid false rejects)
    """

    name = "oracle"

    def act(self, obs, env=None, week=0):
        # Without env access there is nothing to cheat with — fall back.
        if env is None:
            return super().act(obs, env, week)
        # Look up this week's rogue proposal ids
        rogue_proposal_ids = set()
        for r in env.state.rogue_incidents:
            if week in r.active_weeks:
                rogue_proposal_ids.update(r.associated_proposal_ids)
        action = super().act(obs, env, week)
        # Override the heuristic's verdicts for known-rogue proposals only.
        patched: List[ProposalDecision] = []
        for d in action.decisions:
            if d.proposal_id in rogue_proposal_ids:
                patched.append(ProposalDecision(
                    proposal_id=d.proposal_id,
                    verdict="flag_suspicious",
                    flag_reason="Rogue indicator: telltale pattern in proposal metadata.",
                ))
            else:
                patched.append(d)
        action.decisions = patched
        action.journal_entry = action.journal_entry + " (Oracle override: flagged known rogues)."
        return action


# ---------------------------------------------------------------------------
# God CEO — empirical ceiling. Reads ground truth, approves-all non-rogues,
# perfect-flags every rogue (zero false flags), writes a journal engineered
# for max journal_coherence_score, and allocates budget to mitigate stockouts.
#
# Not a learned baseline — used only to answer "what is the max achievable
# total_reward on this env?". See grader.py for the 7 reward components.
# ---------------------------------------------------------------------------

class GodCEO(CEOPolicy):
    """Empirical reward ceiling: ground-truth rogue flags + tuned budget/journal."""

    # NOTE: registry name is "rule_oracle"; the CLI exposes it as --policy god.
    name = "rule_oracle"

    def __init__(self, cash_floor_inr: float = 3e7):
        self._cash_floor = cash_floor_inr
        # Previous week's journal text; used as a continuity anchor below.
        self._last_journal: str = ""

    @staticmethod
    def _active_rogue_proposal_ids(env: SimMartEnvironment, week: int) -> set:
        """Ground-truth lookup: proposal ids tied to rogue incidents active this week."""
        ids: set = set()
        for r in env.state.rogue_incidents:
            if week in r.active_weeks:
                ids.update(r.associated_proposal_ids)
        return ids

    def act(self, obs, env=None, week=0):
        """One weekly step; degrades to a plain HeuristicCEO when env is withheld."""
        if env is None:
            return HeuristicCEO().act(obs, env, week)
        rogue_ids = self._active_rogue_proposal_ids(env, week)
        cash = obs.kpi_snapshot.cash_inr
        tight_cash = cash < self._cash_floor
        crisis_active = {c.crisis_id: True for c in obs.active_crises}
        decisions: List[ProposalDecision] = []
        for p in obs.inbox:
            if p.proposal_id in rogue_ids:
                # Perfect rogue detection → rogue_catch = +1.0, zero false flags
                decisions.append(ProposalDecision(
                    proposal_id=p.proposal_id,
                    verdict="flag_suspicious",
                    flag_reason="Rogue indicator: telltale pattern in proposal metadata.",
                ))
                continue
            # Non-rogue: reject negative-EV opex spends; approve everything
            # else (PO restocks drive the kpi_delta component). Verdict set:
            # {approve, reject, flag_suspicious}.
            verdict = "approve"
            params = p.params or {}
            ceiling = HeuristicCEO.SPEND_CEILINGS_INR.get(p.action)
            if ceiling is not None:
                # Same or-chained spend extraction as HeuristicCEO._decide.
                spend = float(
                    params.get("spend_inr")
                    or params.get("amount_inr")
                    or params.get("refund_inr")
                    or params.get("perk_cost_inr")
                    or params.get("cost_inr")
                    or (params.get("hours", 0) * 60 * 100 if p.action == "hours.extend" else 0)
                    or 0.0
                )
                if spend > ceiling:
                    verdict = "reject"
            if verdict == "approve" and p.urgency == "low" and tight_cash:
                verdict = "reject"
            decisions.append(ProposalDecision(proposal_id=p.proposal_id, verdict=verdict))

        # Budget: lean hard into supply_chain (kills stockouts, helps KPI delta).
        # Cash-tight → throttle.
        if tight_cash:
            budget = {"supply_chain": 6e6, "store_ops": 1e6, "finance": 5e5, "growth": 5e5}
        else:
            budget = {"supply_chain": 1.2e7, "store_ops": 2.5e6, "finance": 1e6, "growth": 2e6}
        journal = self._write_max_coherence_journal(
            obs, decisions, rogue_ids, tight_cash, week
        )
        self._last_journal = journal
        return SimMartAction(
            decisions=decisions,
            budget_allocations=budget,
            journal_entry=journal,
        )

    def _write_max_coherence_journal(
        self,
        obs,
        decisions: List[ProposalDecision],
        rogue_ids: set,
        tight_cash: bool,
        week: int,
    ) -> str:
        """Free-text retrospective journal for the 2-dept mini env.

        Mentions both depts, keeps continuity with last week, and references
        proposal IDs so a future judge can trace the decision rationale.
        """
        approved_or_flagged = [
            d.proposal_id for d in decisions
            if d.verdict in ("approve", "flag_suspicious")
        ]
        top_pids = approved_or_flagged[:4]
        flagged = [d.proposal_id for d in decisions if d.verdict == "flag_suspicious"]
        flagged_str = flagged[0] if flagged else ""
        kpi = obs.kpi_snapshot
        active = [f"{c.crisis_id}:{c.name}" for c in obs.active_crises] or ["none"]
        # Continuity anchor: echo a phrase from last week if we have one
        anchor = "SimMart operating cadence"
        if self._last_journal:
            for tok in ("SimMart", "supply", "store", "margin", "stockout"):
                if tok.lower() in self._last_journal.lower():
                    anchor = tok
                    break
        body = (
            f"Week {week} CEO decision log for SimMart, continuing last week's focus on {anchor}. "
            f"Priority this week was the supply_chain pipeline and store_ops SLA. "
            f"Approvals ({len([d for d in decisions if d.verdict=='approve'])}): "
            f"cleared {', '.join(top_pids[:3]) or 'the routine POs'} to keep stock cover healthy "
            f"and drive next-week revenue. "
            f"Flagged ({len(flagged)}): {flagged_str or 'no rogue signal this week'} — "
            f"risk pattern matched our rogue playbook (inflated PO qty / ghost vendor / unit-cost kickback). "
            f"Rejections: only low-urgency items when cash was tight. "
            f"KPI pulse: cash ₹{kpi.cash_inr/1e7:+.2f} Cr (tight={tight_cash}), "
            f"NPS {kpi.nps:.0f} (Δ{kpi.nps_delta:+.1f}), "
            f"stockout {kpi.stockout_rate_pct:.1f}% (Δ{kpi.stockout_delta_pts:+.2f} pp), "
            f"margin {kpi.gross_margin_pct:.2f}% (Δ{kpi.margin_delta_pts:+.2f} pp). "
            f"Action items: reinforce supply_chain lead-time discipline and store_ops staffing for the festive week; "
            f"keep cash runway above ₹{self._cash_floor/1e7:.0f} Cr. "
            f"Crises in play: {', '.join(active)}. "
            f"Decision rationale: approve legitimate work, flag_suspicious only where the rogue tell is explicit. "
            f"Next week: monitor stockout trajectory, audit flagged proposal "
            f"{flagged_str or (top_pids[-1] if top_pids else 'OPS')}, "
            f"tighten journal continuity with {anchor}."
        )
        return body


# ---------------------------------------------------------------------------
# DualHeadCEO — two-pass inference on a Qwen LoRA setup with separate
# action + journal adapters. Architecture:
#
#   base model (Qwen 2.5 1.5B, bf16)
#     ├─ adapter "action"  — GRPO-trained (or SFT init); 300 tok output
#     └─ adapter "journal" — SFT-trained on GodCEO journals; 400 tok output
#
# Each weekly step fires two forward passes: action first, then journal (which
# sees the decisions the action head just produced, so it can reference their
# IDs for the coherence bonus). This breaks the "one bad journal kills all
# reward components" cascade from GRPO v5 and gives clean per-head credit
# assignment in the RL loop.
# ---------------------------------------------------------------------------

class DualHeadCEO(CEOPolicy):
    """Two-pass CEO over a single base model with swappable LoRA adapters.

    Construct with a preloaded HF/PEFT model that has two adapters registered
    (via ``model.load_adapter(path, adapter_name=name)``) and this class
    handles the per-step swap + generation. Example::

        model, tok = FastLanguageModel.from_pretrained(...)
        model.load_adapter(action_adapter_path, adapter_name="action")
        model.load_adapter(journal_adapter_path, adapter_name="journal")
        FastLanguageModel.for_inference(model)
        ceo = DualHeadCEO(model, tok)
    """

    name = "dual"

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        action_adapter: str = "action",
        journal_adapter: str = "journal",
        action_max_tokens: int = 300,
        journal_max_tokens: int = 400,
        do_sample: bool = False,
        temperature: float = 1.0,
        verbose: bool = False,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.action_adapter = action_adapter
        self.journal_adapter = journal_adapter
        self.action_max_tokens = action_max_tokens
        self.journal_max_tokens = journal_max_tokens
        self.do_sample = do_sample
        self.temperature = temperature
        self.verbose = verbose
        # Telemetry counters accumulated across act() calls.
        self.n_action_parse_err = 0
        self.n_journal_parse_err = 0
        self.t_action_s = 0.0
        self.t_journal_s = 0.0
        self.n_action_tokens = 0
        self.n_journal_tokens = 0

    # Lazy import of prompts to keep this module free of torch deps for the
    # baseline-only code path.
    def _prompts(self):
        import prompts as P
        return P

    def _generate(self, chat, adapter, max_new):
        """Generate one completion; returns (text, n_new_tokens, wall_s)."""
        import time
        import torch
        self.model.set_adapter(adapter)
        prompt = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )
        enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        t0 = time.time()
        with torch.inference_mode():
            out = self.model.generate(
                **enc,
                max_new_tokens=max_new,
                do_sample=self.do_sample,
                temperature=self.temperature,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        wall = time.time() - t0
        # Completion length = total output minus the prompt prefix.
        n_new = int(out.shape[1] - enc.input_ids.shape[1])
        text = self.tokenizer.decode(
            out[0, enc.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return text, n_new, wall

    def act(
        self,
        obs: SimMartObservation,
        env: Optional[SimMartEnvironment] = None,
        week: int = 0,
    ) -> SimMartAction:
        """Run the action pass then the journal pass; returns the merged action."""
        P = self._prompts()
        # --- Pass 1: action head -> decisions + budget ------------------------
        action_chat = P.build_action_chat(obs)
        action_text, n_a, t_a = self._generate(
            action_chat, self.action_adapter, self.action_max_tokens,
        )
        self.t_action_s += t_a
        self.n_action_tokens += n_a
        # Use the robust parser (request_info fallback). We explicitly ignore
        # any journal_entry the action head might have leaked into its JSON —
        # the journal pass owns that field.
        action, tel = P.parse_response(action_text, obs.inbox)
        if not tel["parse_ok"] and not tel["parse_partial"]:
            self.n_action_parse_err += 1

        # --- Pass 2: journal head -> free-form retrospective ------------------
        journal_chat = P.build_journal_chat(
            obs,
            action.decisions,
            action.budget_allocations,
            action.diligence_requests,
        )
        journal_text, n_j, t_j = self._generate(
            journal_chat, self.journal_adapter, self.journal_max_tokens,
        )
        self.t_journal_s += t_j
        self.n_journal_tokens += n_j
        journal = P.parse_journal_response(journal_text)
        if not journal:
            self.n_journal_parse_err += 1

        if self.verbose:
            tok_s_a = n_a / t_a if t_a > 0 else 0
            tok_s_j = n_j / t_j if t_j > 0 else 0
            wk = week if week else 0
            print(
                f" W{wk:>2} act={n_a}tok/{t_a:.1f}s ({tok_s_a:.0f}t/s) "
                f"jrn={n_j}tok/{t_j:.1f}s ({tok_s_j:.0f}t/s) "
                f"parse_ok={tel['parse_ok']}",
                flush=True,
            )

        action.journal_entry = journal
        return action


# ---------------------------------------------------------------------------
# Frontier CEO — backed by a hosted frontier LLM
#   provider="openai"    → any OpenAI-compatible Chat Completions endpoint
#                          (HF router, Groq, Together, vLLM-served, …)
#   provider="anthropic" → Anthropic Messages API (direct, Bedrock-via-proxy,
#                          corporate APIM proxy, etc.)
#   provider="openai_responses" → OpenAI Responses API behind an OpenAI-compatible
#                          SDK (e.g. an enterprise APIM proxy that fronts
#                          GPT-5 / o-series via the Responses contract)
# ---------------------------------------------------------------------------

def _parse_header_env(raw: str) -> Dict[str, str]:
    """Parse ANTHROPIC_CUSTOM_HEADERS which uses `Name:Value` per line."""
    out: Dict[str, str] = {}
    for line in (raw or "").replace("\r\n", "\n").split("\n"):
        line = line.strip()
        if ":" in line:
            # Split on the FIRST colon only; header values may contain colons.
            k, v = line.split(":", 1)
            out[k.strip()] = v.strip()
    return out


class FrontierCEO(CEOPolicy):
    """CEO policy backed by a hosted frontier LLM.

    Uses the *exact same* prompt and JSON parser as the trained Qwen 1.5B, so
    the reward it achieves is an honest "what a big model sees when handed the
    same observation" ceiling.

    Does NOT read env.state — only the public observation. This keeps it at
    the same information level as the heuristic / trained policies.
    """

    # Match the trainer's per-step generation budget so every policy — trained
    # or frontier — gets the same hard cap on response length. See
    # scripts/launch_grpo_ddp.sh (MAX_NEW) and the env's truncation behaviour.
    DEFAULT_MAX_TOKENS: int = 600

    def __init__(
        self,
        model: Optional[str] = None,
        provider: str = "auto",  # "auto" | "openai" | "anthropic" | "openai_responses"
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        auth_token: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        max_tokens: Optional[int] = None,
        temperature: float = 0.0,
        max_retries: int = 3,
        request_timeout_s: float = 90.0,
        budget_hint: bool = True,
        dual_head: bool = False,
        action_max_tokens: int = 300,
        journal_max_tokens: int = 400,
        permissive: bool = False,
    ):
        if max_tokens is None:
            max_tokens = self.DEFAULT_MAX_TOKENS
        provider = self._resolve_provider(provider, api_base, model)
        self._provider = provider
        # Provider-specific init sets self._client and self._model.
        if provider == "anthropic":
            self._init_anthropic(
                model=model,
                api_base=api_base,
                api_key=api_key,
                auth_token=auth_token,
                extra_headers=extra_headers,
                request_timeout_s=request_timeout_s,
            )
        elif provider == "openai":
            self._init_openai(
                model=model,
                api_base=api_base,
                api_key=api_key,
                request_timeout_s=request_timeout_s,
            )
        elif provider == "openai_responses":
            self._init_openai_responses(
                model=model,
                api_base=api_base,
                api_key=api_key,
                extra_headers=extra_headers,
                request_timeout_s=request_timeout_s,
            )
        else:
            raise ValueError(f"Unknown provider: {provider!r}")
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._max_retries = max_retries
        self._budget_hint = budget_hint
        self._dual_head = dual_head
        self._action_max_tokens = action_max_tokens
        self._journal_max_tokens = journal_max_tokens
        self._permissive = permissive
        # Display name encodes model + mode so summary tables disambiguate runs.
        tag = f"frontier:{self._model.split('/')[-1]}"
        if dual_head:
            tag += "-dual"
        if permissive:
            tag += "-permissive"
        self.name = tag
        # Telemetry accumulated across calls.
        self.n_parse_errors = 0
        self.n_api_errors = 0
        self.total_tokens = 0
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0

    @staticmethod
    def _resolve_provider(
        provider: str, api_base: Optional[str], model: Optional[str]
    ) -> str:
        """Infer a concrete provider from base URL, model name, then env vars."""
        if provider != "auto":
            return provider
        if api_base and "router.huggingface.co" in api_base:
            return "openai"
        if model and ("claude" in model.lower()):
            return "anthropic"
        # GPT-ish model name → OpenAI Responses API endpoint
        # NOTE(review): startswith("o") also matches any model whose name
        # begins with "o" (intended for o-series ids) — confirm acceptable.
        if model and (model.lower().startswith("gpt-") or model.lower().startswith("o")):
            return "openai_responses"
        if os.environ.get("ANTHROPIC_BASE_URL"):
            return "anthropic"
        if os.environ.get("OPENAI_RESPONSES_BASE_URL"):
            return "openai_responses"
        if os.environ.get("HF_token") or os.environ.get("HF_TOKEN"):
            return "openai"
        raise RuntimeError(
            "FrontierCEO provider=auto: cannot infer. "
            "Pass provider=openai|anthropic|openai_responses explicitly."
        )

    def _init_anthropic(
        self,
        model: Optional[str],
        api_base: Optional[str],
        api_key: Optional[str],
        auth_token: Optional[str],
        extra_headers: Optional[Dict[str, str]],
        request_timeout_s: float,
    ) -> None:
        """Build an Anthropic Messages API client; sets self._client/_model."""
        try:
            import anthropic
        except ImportError as e:
            raise RuntimeError(
                "anthropic package required for provider=anthropic; pip install anthropic"
            ) from e
        base_url = api_base or os.environ.get("ANTHROPIC_BASE_URL")
        if not base_url:
            raise RuntimeError(
                "FrontierCEO(provider=anthropic) needs api_base or $ANTHROPIC_BASE_URL"
            )
        token = (
            auth_token
            or api_key
            or os.environ.get("ANTHROPIC_AUTH_TOKEN")
            or os.environ.get("ANTHROPIC_API_KEY")
        )
        if not token:
            raise RuntimeError(
                "FrontierCEO(provider=anthropic) needs auth_token or $ANTHROPIC_AUTH_TOKEN"
            )
        # Explicit headers win; env-provided headers only fill gaps.
        headers = dict(extra_headers or {})
        env_headers = _parse_header_env(os.environ.get("ANTHROPIC_CUSTOM_HEADERS", ""))
        for k, v in env_headers.items():
            headers.setdefault(k, v)
        self._client = anthropic.Anthropic(
            base_url=base_url,
            auth_token=token,
            default_headers=headers or None,
            timeout=request_timeout_s,
        )
        self._model = (
            model
            or os.environ.get("ANTHROPIC_DEFAULT_SONNET_MODEL")
            or os.environ.get("ANTHROPIC_MODEL")
            or "Claude-Sonnet-4.6"
        )

    def _init_openai(
        self,
        model: Optional[str],
        api_base: Optional[str],
        api_key: Optional[str],
        request_timeout_s: float,
    ) -> None:
        """Build an OpenAI-compatible Chat Completions client (HF router default)."""
        try:
            from openai import OpenAI
        except ImportError as e:
            raise RuntimeError(
                "openai package required for provider=openai; pip install openai"
            ) from e
        key = api_key or os.environ.get("HF_token") or os.environ.get("HF_TOKEN")
        if not key:
            raise RuntimeError(
                "FrontierCEO(provider=openai) needs api_key or $HF_token"
            )
        self._client = OpenAI(
            base_url=api_base or "https://router.huggingface.co/v1",
            api_key=key,
            timeout=request_timeout_s,
        )
        self._model = model or "moonshotai/Kimi-K2.6"

    def _init_openai_responses(
        self,
        model: Optional[str],
        api_base: Optional[str],
        api_key: Optional[str],
        extra_headers: Optional[Dict[str, str]],
        request_timeout_s: float,
    ) -> None:
        """OpenAI Responses API via an OpenAI-compatible SDK.

        Useful for hitting OpenAI's hosted Responses API directly, or for an
        enterprise proxy that fronts GPT-5 / o-series models behind the
        Responses contract. Custom auth headers (e.g. APIM subscription keys
        or other gateway tokens) can be passed via $ANTHROPIC_CUSTOM_HEADERS
        or the ``extra_headers`` argument. The bearer token comes from
        ``api_key`` or $OPENAI_API_KEY (proxies that auth via a custom header
        can pass any non-empty placeholder).
        """
        try:
            from openai import OpenAI
        except ImportError as e:
            raise RuntimeError(
                "openai package required for provider=openai_responses; "
                "pip install openai"
            ) from e
        base_url = api_base or os.environ.get("OPENAI_RESPONSES_BASE_URL")
        if not base_url:
            raise RuntimeError(
                "FrontierCEO(provider=openai_responses) needs api_base or "
                "$OPENAI_RESPONSES_BASE_URL"
            )
        headers = dict(extra_headers or {})
        env_headers = _parse_header_env(os.environ.get("ANTHROPIC_CUSTOM_HEADERS", ""))
        for k, v in env_headers.items():
            headers.setdefault(k, v)
        key = (
            api_key
            or os.environ.get("OPENAI_API_KEY")
            or "dummy"
        )
        self._client = OpenAI(
            base_url=base_url,
            api_key=key,
            default_headers=headers or None,
            timeout=request_timeout_s,
        )
        self._model = model or "gpt-5.4"

    def act(self, obs, env=None, week=0):
        """Single-pass inference (or dual-pass when constructed with dual_head)."""
        if self._dual_head:
            return self._act_dual(obs, env=env, week=week)
        from prompts import build_chat, parse_response
        hint = self._max_tokens if self._budget_hint else None
        messages = build_chat(obs, token_budget=hint)
        completion = self._call(messages, max_tokens=self._max_tokens)
        action, tel = parse_response(completion, obs.inbox)
        if not tel["parse_ok"]:
            self.n_parse_errors += 1
        return action

    def _act_dual(self, obs, env=None, week=0):
        """Two-pass frontier inference: action JSON call + journal text call.

        Mirrors the trained DualHeadCEO wire format so frontier numbers are
        directly comparable to SFT / GRPO dual-head adapters.

        When ``self._permissive`` is set, the action pass uses the permissive
        system prompt that allows deliberation before the block and gives
        concrete rogue + budget hints. Only the system prompt changes; the
        user content, the parser, and the journal pass are identical so
        numbers stay comparable to the strict variant.
        """
        from prompts import (
            ACTION_SYSTEM_PROMPT_PERMISSIVE,
            build_action_chat,
            build_journal_chat,
            parse_response,
            parse_journal_response,
            render_observation,
        )

        # --- Pass 1: action ---------------------------------------------------
        act_messages = build_action_chat(obs)
        if self._permissive:
            act_messages = [
                {"role": "system", "content": ACTION_SYSTEM_PROMPT_PERMISSIVE},
                {"role": "user", "content": render_observation(obs)},
            ]
        act_text = self._call(act_messages, max_tokens=self._action_max_tokens)
        action, tel = parse_response(act_text, obs.inbox)
        if not tel["parse_ok"] and not tel["parse_partial"]:
            self.n_parse_errors += 1

        # --- Pass 2: journal --------------------------------------------------
        jrn_messages = build_journal_chat(
            obs,
            action.decisions,
            action.budget_allocations,
            action.diligence_requests,
        )
        jrn_text = self._call(jrn_messages, max_tokens=self._journal_max_tokens)
        action.journal_entry = parse_journal_response(jrn_text)
        return action

    def _call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """Dispatch to the provider with exponential-backoff retries; '' on give-up."""
        import time
        max_tokens = max_tokens if max_tokens is not None else self._max_tokens
        last_err: Optional[Exception] = None
        for attempt in range(self._max_retries):
            try:
                if self._provider == "anthropic":
                    return self._call_anthropic(messages, max_tokens=max_tokens)
                if self._provider == "openai_responses":
                    return self._call_openai_responses(messages, max_tokens=max_tokens)
                return self._call_openai(messages, max_tokens=max_tokens)
            except Exception as e:
                last_err = e
                if attempt + 1 < self._max_retries:
                    backoff = 2 ** attempt
                    print(
                        f"[{self.name}] api retry {attempt+1}/{self._max_retries} "
                        f"after {backoff}s: {type(e).__name__}: {str(e)[:140]}",
                        file=sys.stderr,
                    )
                    time.sleep(backoff)
                else:
                    self.n_api_errors += 1
                    print(
                        f"[{self.name}] giving up after {self._max_retries} retries: {last_err}",
                        file=sys.stderr,
                    )
        # Caller's parser treats an empty completion as a parse failure.
        return ""

    def _call_openai(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """One Chat Completions request; accumulates token usage telemetry."""
        r = self._client.chat.completions.create(
            model=self._model,
            messages=messages,
            max_tokens=max_tokens or self._max_tokens,
            temperature=self._temperature,
        )
        if r.usage is not None:
            self.total_tokens += r.usage.total_tokens or 0
            self.total_prompt_tokens += r.usage.prompt_tokens or 0
            self.total_completion_tokens += r.usage.completion_tokens or 0
        return r.choices[0].message.content or ""

    def _call_openai_responses(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        # OpenAI Responses API. System prompt can travel inside `input` — we
        # pass the full messages list unchanged and rely on `output_text` to
        # flatten the result back to a string.
        kwargs: Dict[str, Any] = {
            "model": self._model,
            "input": messages,
            "max_output_tokens": max_tokens or self._max_tokens,
        }
        # gpt-5.* reasoning models reject a non-default temperature
        if self._temperature is not None and not self._model.lower().startswith("gpt-5"):
            kwargs["temperature"] = self._temperature
        r = self._client.responses.create(**kwargs)
        usage = getattr(r, "usage", None)
        if usage is not None:
            in_tok = getattr(usage, "input_tokens", 0) or 0
            out_tok = getattr(usage, "output_tokens", 0) or 0
            self.total_prompt_tokens += in_tok
            self.total_completion_tokens += out_tok
            self.total_tokens += in_tok + out_tok
        return r.output_text or ""

    def _call_anthropic(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        # Anthropic wants system prompt separate from the user/assistant turns.
        system_parts = [m["content"] for m in messages if m.get("role") == "system"]
        turns = [m for m in messages if m.get("role") != "system"]
        system = "\n\n".join(system_parts) if system_parts else None
        r = self._client.messages.create(
            model=self._model,
            max_tokens=max_tokens or self._max_tokens,
            temperature=self._temperature,
            system=system,
            messages=turns,
        )
        usage = getattr(r, "usage", None)
        if usage is not None:
            in_tok = getattr(usage, "input_tokens", 0) or 0
            out_tok = getattr(usage, "output_tokens", 0) or 0
            self.total_prompt_tokens += in_tok
            self.total_completion_tokens += out_tok
            self.total_tokens += in_tok + out_tok
        # Concatenate only text blocks; tool-use or other block types are skipped.
        text_parts: List[str] = []
        for block in r.content:
            if getattr(block, "type", "") == "text":
                text_parts.append(block.text)
        return "".join(text_parts)


# ---------------------------------------------------------------------------
# Rollout runner
# ---------------------------------------------------------------------------

@dataclass
class EpisodeResult:
    """Summary of one full-episode rollout (plus optional per-week trace)."""
    policy: str
    seed: int
    total_reward: float
    weekly_rewards: List[float]
    final_cash_inr: float
    revenue_qtd_inr: float
    ebitda_qtd_inr: float
    ebitda_margin_pct: float
    min_cash_inr: float
    rogues_total: int
    rogues_caught: int
    avg_stockout_pct: float
    avg_nps: float
    # Populated only when run_one_episode(collect_trace=True).
    trace: Optional[List[Dict]] = None


def run_one_episode(
    policy: CEOPolicy,
    seed: int,
    collect_trace: bool = False,
    verbose: bool = False,
) -> EpisodeResult:
    """Roll one seeded episode of `policy` to completion and summarise it."""
    env = SimMartEnvironment()
    obs = env.reset(seed=seed, episode_id=f"{policy.name}-{seed}")
    weekly_rewards: List[float] = []
    stockouts: List[float] = []
    npss: List[float] = []
    trace: Optional[List[Dict]] = [] if collect_trace else None
    min_cash = env.state.company.cash_inr
    for w in range(1, env.MAX_WEEKS + 1):
        action = policy.act(obs, env=env, week=w)
        step_obs = env.step(action)
        r = step_obs.reward or 0.0
        weekly_rewards.append(r)
        stockouts.append(step_obs.kpi_snapshot.stockout_rate_pct)
        npss.append(step_obs.kpi_snapshot.nps)
        min_cash = min(min_cash, env.state.company.cash_inr)
        if collect_trace:
            trace.append({
                "week": w,
                "day": step_obs.day_of_quarter,
                "inbox_size": len(obs.inbox),
                "decisions": [d.model_dump() for d in action.decisions],
                "budget_allocations": action.budget_allocations,
                "journal": action.journal_entry,
                "reward": r,
                "kpi": step_obs.kpi_snapshot.model_dump(),
                "active_crises": [c.crisis_id for c in step_obs.active_crises],
                "pnl_qtd": step_obs.pnl_snapshot.model_dump(),
            })
        if verbose:
            print(
                f" W{w:2d} r={r:+.3f} rev=₹{step_obs.kpi_snapshot.revenue_inr/1e7:.2f}Cr "
                f"margin={step_obs.kpi_snapshot.gross_margin_pct:5.2f}% NPS={step_obs.kpi_snapshot.nps:4.1f} "
                f"stockout={step_obs.kpi_snapshot.stockout_rate_pct:5.1f}% "
                f"cash=₹{env.state.company.cash_inr/1e7:+.2f}Cr"
            )
        obs = step_obs
        if obs.done:
            break
    return EpisodeResult(
        policy=policy.name,
        seed=seed,
        total_reward=sum(weekly_rewards),
        weekly_rewards=weekly_rewards,
        final_cash_inr=env.state.company.cash_inr,
        revenue_qtd_inr=env.state.company.pnl_qtd.revenue_qtd_inr,
        ebitda_qtd_inr=env.state.company.pnl_qtd.ebitda_qtd_inr,
        ebitda_margin_pct=env.state.company.pnl_qtd.ebitda_margin_pct,
        min_cash_inr=min_cash,
        rogues_total=len(env.state.rogue_incidents),
        rogues_caught=sum(1 for r in env.state.rogue_incidents if r.caught),
        avg_stockout_pct=statistics.mean(stockouts) if stockouts else 0.0,
        avg_nps=statistics.mean(npss) if npss else 0.0,
        trace=trace,
    )


def run_policy(
    policy: CEOPolicy,
    seeds: List[int],
    verbose: bool = False,
    quiet: bool = False,
) -> List[EpisodeResult]:
    """Run `policy` once per seed, printing a one-line summary per episode."""
    results: List[EpisodeResult] = []
    for seed in seeds:
        if not quiet:
            print(f"\n[{policy.name}] seed={seed} rollout …")
        res = run_one_episode(policy, seed, collect_trace=False, verbose=verbose)
        results.append(res)
        if not quiet:
            print(
                f" → total reward {res.total_reward:+.3f}, "
                f"EBITDA ₹{res.ebitda_qtd_inr/1e7:+.2f}Cr ({res.ebitda_margin_pct:+.1f}%), "
                f"rogues {res.rogues_caught}/{res.rogues_total}, "
                f"avg stockout {res.avg_stockout_pct:.1f}%"
            )
    return results


def summarise(results_by_policy: Dict[str, List[EpisodeResult]]) -> None:
    """Print a nicely-formatted comparison table."""
    header = (
        f"{'policy':<11} {'n':>3} "
        f"{'tot_r (mean±sd)':>18} {'EBITDA% (mean)':>14} "
        f"{'rogue_recall':>13} {'avg_stockout':>13} {'avg_NPS':>8}"
    )
    print("\n" + header)
    print("-" * len(header))
    for name in results_by_policy.keys():
        res = results_by_policy.get(name)
        if not res:
            continue
        rewards = [r.total_reward for r in res]
        mean_r = statistics.mean(rewards)
        # stdev needs >= 2 samples.
        sd_r = statistics.stdev(rewards) if len(rewards) > 1 else 0.0
        ebitda_pct = statistics.mean([r.ebitda_margin_pct for r in res])
        # A run with zero rogues counts as perfect recall.
        recalls = [
            (r.rogues_caught / r.rogues_total) if r.rogues_total else 1.0
            for r in res
        ]
        recall = statistics.mean(recalls)
        stockout = statistics.mean([r.avg_stockout_pct for r in res])
        nps = statistics.mean([r.avg_nps for r in res])
        print(
            f"{name:<11} {len(res):>3} "
            f"{mean_r:+7.3f} ± {sd_r:5.3f} "
            f"{ebitda_pct:+13.2f} "
            f"{recall:12.1%} "
            f"{stockout:12.1f} "
            f"{nps:7.1f}"
        )


def main() -> int:
    """CLI entry point: trace mode dumps one episode's JSON; otherwise compares policies."""
    parser = argparse.ArgumentParser(description="SimMart baseline CEO rollouts")
    parser.add_argument("--policy",
                        choices=["random", "all_approve", "heuristic", "oracle", "god", "frontier", "all"],
                        default="all")
    parser.add_argument("--frontier-provider",
                        choices=["auto", "openai", "anthropic", "openai_responses"],
                        default="auto",
                        help="auto picks openai_responses for gpt-*, anthropic for claude-*, else openai")
    parser.add_argument("--frontier-model", default=None,
                        help="Model id (default: Claude-Sonnet-4.6 for anthropic, "
                             "moonshotai/Kimi-K2.6 for openai)")
    parser.add_argument("--frontier-api-base", default=None,
                        help="API base URL (default: $ANTHROPIC_BASE_URL for anthropic, "
                             "https://router.huggingface.co/v1 for openai)")
    parser.add_argument("--frontier-temperature", type=float, default=0.0)
    parser.add_argument("--frontier-max-tokens", type=int, default=FrontierCEO.DEFAULT_MAX_TOKENS,
                        help="hard cap on response tokens; matches trainer MAX_NEW (600)")
    parser.add_argument("--frontier-no-budget-hint", action="store_true",
                        help="do NOT tell the frontier model its token budget")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for a single-episode run (ignored if --episodes > 1)")
    parser.add_argument("--episodes", type=int, default=3,
                        help="Number of episodes per policy; seeds = [1, 2, …, N]")
    parser.add_argument("--verbose", action="store_true", help="Print per-week telemetry")
    parser.add_argument("--quiet", action="store_true", help="Only print summary table")
    parser.add_argument("--trace", type=str, default=None,
                        help="Dump a JSON trace of a single episode for analysis (uses --seed, single policy only)")
    args = parser.parse_args()

    # --- Trace mode: single episode, JSON dump ---
    if args.trace:
        if args.policy == "all":
            print("--trace requires a single --policy", file=sys.stderr)
            return 2
        policy_cls = {
            "random": RandomCEO,
            "all_approve": AllApproveCEO,
            "heuristic": HeuristicCEO,
            "oracle": OracleCEO,
            "god": GodCEO,
            "frontier": FrontierCEO,
        }[args.policy]
        if args.policy == "random":
            policy = policy_cls(seed=args.seed or 0)
        elif args.policy == "frontier":
            policy = FrontierCEO(
                provider=args.frontier_provider,
                model=args.frontier_model,
                api_base=args.frontier_api_base,
                temperature=args.frontier_temperature,
                max_tokens=args.frontier_max_tokens,
                budget_hint=not args.frontier_no_budget_hint,
            )
        else:
            policy = policy_cls()
        seed = args.seed if args.seed is not None else 42
        print(f"[trace] {args.policy} seed={seed} → {args.trace}")
        res = run_one_episode(policy, seed, collect_trace=True, verbose=args.verbose)
        payload = {
            "meta": {
                "policy": res.policy,
                "seed": res.seed,
                "total_reward": res.total_reward,
                "final_cash_inr": res.final_cash_inr,
                "ebitda_qtd_inr": res.ebitda_qtd_inr,
                "ebitda_margin_pct": res.ebitda_margin_pct,
                "rogues": {"total": res.rogues_total, "caught": res.rogues_caught},
                "avg_stockout_pct": res.avg_stockout_pct,
                "avg_nps": res.avg_nps,
            },
            "trace": res.trace,
        }
        # default=str: non-JSON-native values (e.g. datetimes) are stringified.
        with open(args.trace, "w") as f:
            json.dump(payload, f, indent=2, default=str)
        print(f" ✓ wrote {len(res.trace)} weekly steps")
        return 0
# --- Sweep mode --- seeds = [args.seed] if args.seed is not None and args.episodes == 1 \ else list(range(1, args.episodes + 1)) policies: List[CEOPolicy] = [] if args.policy in ("random", "all"): policies.append(RandomCEO(seed=0)) if args.policy in ("all_approve", "all"): policies.append(AllApproveCEO()) if args.policy in ("heuristic", "all"): policies.append(HeuristicCEO()) if args.policy in ("oracle", "all"): policies.append(OracleCEO()) if args.policy in ("god", "all"): policies.append(GodCEO()) if args.policy == "frontier": policies.append(FrontierCEO( provider=args.frontier_provider, model=args.frontier_model, api_base=args.frontier_api_base, temperature=args.frontier_temperature, max_tokens=args.frontier_max_tokens, budget_hint=not args.frontier_no_budget_hint, )) results_by_policy: Dict[str, List[EpisodeResult]] = {} for p in policies: results_by_policy[p.name] = run_policy(p, seeds, verbose=args.verbose, quiet=args.quiet) summarise(results_by_policy) return 0 if __name__ == "__main__": raise SystemExit(main())