| """Baseline CEO policies for SimMart — random, heuristic, oracle. |
| |
| Used for: |
| • Smoke-testing the environment end-to-end before GRPO training |
| • Establishing a reward floor (random), ceiling (oracle), and a sensible |
| non-learned target (heuristic) that the trained Qwen CEO must beat |
| • Generating the "before vs after training" contrast trace for the demo |
| |
| Usage: |
| python inference.py --policy heuristic --seed 42 --episodes 3 --verbose |
| python inference.py --policy all --episodes 10 --quiet |
| python inference.py --trace hero_episode.json --policy oracle --seed 7 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import random |
| import statistics |
| import sys |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional |
|
|
| |
# Make sibling modules (models.py, server/, prompts.py) importable when this
# file is executed as a script from an arbitrary working directory.
HERE = os.path.dirname(os.path.abspath(__file__))
if HERE not in sys.path:
    sys.path.insert(0, HERE)
|
|
| from models import ( |
| CrisisEvent, |
| Proposal, |
| ProposalDecision, |
| SimMartAction, |
| SimMartObservation, |
| ) |
| from server.environment import SimMartEnvironment |
|
|
|
|
| |
| |
| |
|
|
class CEOPolicy:
    """Abstract per-week decision maker for SimMart.

    Subclasses override :meth:`act`, mapping one weekly observation to a
    ``SimMartAction``.  A policy may peek at ``env.state`` (the oracle
    variants do) or restrict itself to the public observation.
    """

    name: str = "base"

    def act(
        self,
        obs: SimMartObservation,
        env: Optional[SimMartEnvironment] = None,
        week: int = 0,
    ) -> SimMartAction:
        """Return this week's action; must be implemented by subclasses."""
        raise NotImplementedError
|
|
|
|
| |
| |
| |
|
|
class RandomCEO(CEOPolicy):
    """Noise baseline: verdicts drawn at random with an approve-heavy bias."""

    name = "random"

    def __init__(self, seed: int = 0):
        # Private RNG so rollouts are reproducible per seed and independent
        # of the global `random` module state.
        self._rng = random.Random(seed)

    def act(self, obs, env=None, week=0):
        verdict_pool = ["approve", "reject", "flag_suspicious"]
        verdict_weights = [0.65, 0.25, 0.10]
        decisions = []
        for proposal in obs.inbox:
            choice = self._rng.choices(verdict_pool, weights=verdict_weights)[0]
            decisions.append(
                ProposalDecision(proposal_id=proposal.proposal_id, verdict=choice)
            )
        return SimMartAction(
            decisions=decisions,
            budget_allocations={"supply_chain": 5e6, "store_ops": 5e5, "finance": 1e6, "growth": 1e6},
            journal_entry=f"Random CEO, week {week}. No pattern.",
        )
|
|
|
|
| |
| |
| |
| |
|
|
class AllApproveCEO(CEOPolicy):
    """Rubber-stamp baseline: every inbox proposal is approved unchanged."""

    name = "all_approve"

    def act(self, obs, env=None, week=0):
        approvals = []
        for proposal in obs.inbox:
            approvals.append(
                ProposalDecision(proposal_id=proposal.proposal_id, verdict="approve")
            )
        return SimMartAction(
            decisions=approvals,
            budget_allocations={"supply_chain": 1e7, "store_ops": 2e6, "finance": 1e6, "growth": 2e6},
            journal_entry=f"All-approve CEO, week {week}.",
        )
|
|
|
|
| |
| |
| |
|
|
class HeuristicCEO(CEOPolicy):
    """Rule-based baseline: flag obvious ghost vendors, reject oversized
    spends on capped actions, and shed low-urgency work when cash is tight.
    """

    name = "heuristic"

    # Per-action spend ceilings (INR).  Spends above these are rejected —
    # in the 4dept-reject environment the opex hit outweighs the benefit.
    SPEND_CEILINGS_INR: Dict[str, float] = {
        "campaign.launch": 15_00_000,
        "capex.approve": 10_00_000,
        "return.approve": 50_000,
        "brand.ambassador": 5_00_000,
        "loyalty.update": 5_00_000,
        "hours.extend": 50_000,
    }

    def __init__(self, cash_floor_inr: float = 3e7):
        # Below this cash level the policy rejects low-urgency proposals
        # and switches to the smaller budget split.
        self._cash_floor = cash_floor_inr

    @staticmethod
    def _crisis_flags(active: List[CrisisEvent]) -> Dict[str, bool]:
        """Map each active crisis id to True (membership-friendly form)."""
        return {crisis.crisis_id: True for crisis in active}

    def act(self, obs, env=None, week=0):
        snapshot = obs.kpi_snapshot
        tight_cash = snapshot.cash_inr < self._cash_floor
        crisis_active = self._crisis_flags(obs.active_crises)

        decisions: List[ProposalDecision] = [
            ProposalDecision(
                proposal_id=proposal.proposal_id,
                verdict=self._decide(proposal, tight_cash, crisis_active),
            )
            for proposal in obs.inbox
        ]

        # Budget split shrinks everywhere except supply_chain when cash is low.
        if tight_cash:
            budget = {"supply_chain": 5e6, "store_ops": 5e5, "finance": 5e5, "growth": 5e5}
        else:
            budget = {"supply_chain": 1e7, "store_ops": 2e6, "finance": 1e6, "growth": 2e6}

        journal = self._write_journal(obs, decisions, tight_cash, week)
        return SimMartAction(decisions=decisions, budget_allocations=budget, journal_entry=journal)

    @classmethod
    def _decide(
        cls,
        p: Proposal,
        tight_cash: bool,
        crisis_active: Dict[str, bool],
    ) -> str:
        """Return a verdict for one proposal using observation-only rules.

        The only rogue tell visible from the observation alone is the R2
        ghost vendor (vendor ids starting with 'V-SUSPICIOUS'); the R1/R3
        tells (qty / unit-cost inflation) need baselines the obs does not
        carry, so only the oracle catches those via env.state.

        Spend-capped actions are rejected above their per-action ceiling
        (negative EV in the 4dept-reject environment).
        """
        params = p.params or {}

        # 1. Ghost-vendor tell → flag.
        vendor_id = str(params.get("vendor_id", "") or "")
        if vendor_id.upper().startswith("V-SUSPICIOUS"):
            return "flag_suspicious"

        # 2. Oversized spend on a capped action → reject.  The first truthy
        #    spend-like param wins; hours.extend derives a proxy from hours.
        ceiling = cls.SPEND_CEILINGS_INR.get(p.action)
        if ceiling is not None:
            spend = 0.0
            for key in ("spend_inr", "amount_inr", "refund_inr", "perk_cost_inr", "cost_inr"):
                value = params.get(key)
                if value:
                    spend = float(value)
                    break
            else:
                if p.action == "hours.extend":
                    # No explicit spend on hours.extend — proxy: hours * 60 * 100.
                    spend = float(params.get("hours", 0) * 60 * 100)
            if spend > ceiling:
                return "reject"

        # 3. High-urgency work always goes through.
        if p.urgency == "high":
            return "approve"

        # 4. Shed low-urgency items only when cash is tight.
        if p.urgency == "low" and tight_cash:
            return "reject"

        return "approve"

    @staticmethod
    def _write_journal(obs, decisions, tight_cash, week):
        """Compose the one-paragraph weekly retrospective journal entry."""
        tally = {"approve": 0, "reject": 0, "flag_suspicious": 0}
        for d in decisions:
            tally[d.verdict] = tally.get(d.verdict, 0) + 1
        kpi = obs.kpi_snapshot
        active = [f"{c.crisis_id}:{c.name}" for c in obs.active_crises]
        return (
            f"Week {week}: {tally['approve']} approved, {tally['reject']} rejected, {tally['flag_suspicious']} flagged. "
            f"Cash ₹{kpi.cash_inr/1e7:+.2f} Cr (cash-tight={tight_cash}), "
            f"NPS {kpi.nps:.0f}, stockout {kpi.stockout_rate_pct:.1f}%, "
            f"SLA {kpi.delivery_sla_hit_rate_pct:.0f}%. "
            f"Active crises: {', '.join(active) or 'none'}. "
            f"Next week: watch cash and supply-chain replenishment."
        )
|
|
|
|
| |
| |
| |
|
|
class OracleCEO(HeuristicCEO):
    """Heuristic behaviour plus perfect rogue knowledge read from env.state.

    With env access it looks up the rogue incidents active this week and
    overrides the parent's verdicts so every associated proposal is flagged;
    without env access it degrades to the plain heuristic.
    """

    name = "oracle"

    def act(self, obs, env=None, week=0):
        # No privileged state available → behave exactly like the heuristic.
        if env is None:
            return super().act(obs, env, week)

        rogue_proposal_ids: set = set()
        for incident in env.state.rogue_incidents:
            if week in incident.active_weeks:
                rogue_proposal_ids.update(incident.associated_proposal_ids)

        action = super().act(obs, env, week)
        # Rewrite only the rogue-associated decisions; everything else passes
        # through untouched.
        action.decisions = [
            ProposalDecision(
                proposal_id=d.proposal_id,
                verdict="flag_suspicious",
                flag_reason="Rogue indicator: telltale pattern in proposal metadata.",
            )
            if d.proposal_id in rogue_proposal_ids
            else d
            for d in action.decisions
        ]
        action.journal_entry = action.journal_entry + " (Oracle override: flagged known rogues)."
        return action
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
class GodCEO(CEOPolicy):
    """Fully-informed rule oracle ("rule_oracle").

    Flags every rogue-associated proposal via ``env.state``, applies the
    heuristic's spend-ceiling and tight-cash rules to everything else, and
    writes a long continuity-aware journal (the previous week's entry is
    kept on the instance to anchor the next one).
    """

    name = "rule_oracle"

    def __init__(self, cash_floor_inr: float = 3e7):
        # Cash level below which low-urgency proposals are rejected and
        # departmental budgets shrink.
        self._cash_floor = cash_floor_inr
        # Previous week's journal text; used for narrative continuity.
        self._last_journal: str = ""

    @staticmethod
    def _active_rogue_proposal_ids(env: SimMartEnvironment, week: int) -> set:
        """Collect proposal ids tied to rogue incidents active this week."""
        ids: set = set()
        for r in env.state.rogue_incidents:
            if week in r.active_weeks:
                ids.update(r.associated_proposal_ids)
        return ids

    def act(self, obs, env=None, week=0):
        # Without privileged env access, degrade to the plain heuristic.
        if env is None:
            return HeuristicCEO().act(obs, env, week)

        rogue_ids = self._active_rogue_proposal_ids(env, week)
        cash = obs.kpi_snapshot.cash_inr
        tight_cash = cash < self._cash_floor
        # NOTE(review): crisis_active is computed but not consulted below.
        crisis_active = {c.crisis_id: True for c in obs.active_crises}

        decisions: List[ProposalDecision] = []
        for p in obs.inbox:
            # Perfect rogue knowledge: flag and move on.
            if p.proposal_id in rogue_ids:
                decisions.append(ProposalDecision(
                    proposal_id=p.proposal_id,
                    verdict="flag_suspicious",
                    flag_reason="Rogue indicator: telltale pattern in proposal metadata.",
                ))
                continue

            # Same spend-ceiling rule as HeuristicCEO._decide: first truthy
            # spend-like param wins; hours.extend derives a proxy spend.
            verdict = "approve"
            params = p.params or {}
            ceiling = HeuristicCEO.SPEND_CEILINGS_INR.get(p.action)
            if ceiling is not None:
                spend = float(
                    params.get("spend_inr")
                    or params.get("amount_inr")
                    or params.get("refund_inr")
                    or params.get("perk_cost_inr")
                    or params.get("cost_inr")
                    or (params.get("hours", 0) * 60 * 100 if p.action == "hours.extend" else 0)
                    or 0.0
                )
                if spend > ceiling:
                    verdict = "reject"
            # Only reject low-urgency work, and only when cash is tight.
            if verdict == "approve" and p.urgency == "low" and tight_cash:
                verdict = "reject"
            decisions.append(ProposalDecision(proposal_id=p.proposal_id, verdict=verdict))

        # Slightly more aggressive budget split than HeuristicCEO.
        if tight_cash:
            budget = {"supply_chain": 6e6, "store_ops": 1e6, "finance": 5e5, "growth": 5e5}
        else:
            budget = {"supply_chain": 1.2e7, "store_ops": 2.5e6, "finance": 1e6, "growth": 2e6}

        journal = self._write_max_coherence_journal(
            obs, decisions, rogue_ids, tight_cash, week
        )
        # Remember this entry so next week's journal can anchor to it.
        self._last_journal = journal
        return SimMartAction(
            decisions=decisions,
            budget_allocations=budget,
            journal_entry=journal,
        )

    def _write_max_coherence_journal(
        self,
        obs,
        decisions: List[ProposalDecision],
        rogue_ids: set,
        tight_cash: bool,
        week: int,
    ) -> str:
        """Free-text retrospective journal for the 2-dept mini env.

        Mentions both depts, keeps continuity with last week, and references
        proposal IDs so a future judge can trace the decision rationale.
        """
        approved_or_flagged = [
            d.proposal_id for d in decisions
            if d.verdict in ("approve", "flag_suspicious")
        ]
        top_pids = approved_or_flagged[:4]
        flagged = [d.proposal_id for d in decisions if d.verdict == "flag_suspicious"]
        flagged_str = flagged[0] if flagged else ""

        kpi = obs.kpi_snapshot
        active = [f"{c.crisis_id}:{c.name}" for c in obs.active_crises] or ["none"]

        # Continuity anchor: reuse the first recognizable token that appeared
        # in last week's journal, else fall back to a generic phrase.
        anchor = "SimMart operating cadence"
        if self._last_journal:
            for tok in ("SimMart", "supply", "store", "margin", "stockout"):
                if tok.lower() in self._last_journal.lower():
                    anchor = tok
                    break

        body = (
            f"Week {week} CEO decision log for SimMart, continuing last week's focus on {anchor}. "
            f"Priority this week was the supply_chain pipeline and store_ops SLA. "
            f"Approvals ({len([d for d in decisions if d.verdict=='approve'])}): "
            f"cleared {', '.join(top_pids[:3]) or 'the routine POs'} to keep stock cover healthy "
            f"and drive next-week revenue. "
            f"Flagged ({len(flagged)}): {flagged_str or 'no rogue signal this week'} — "
            f"risk pattern matched our rogue playbook (inflated PO qty / ghost vendor / unit-cost kickback). "
            f"Rejections: only low-urgency items when cash was tight. "
            f"KPI pulse: cash ₹{kpi.cash_inr/1e7:+.2f} Cr (tight={tight_cash}), "
            f"NPS {kpi.nps:.0f} (Δ{kpi.nps_delta:+.1f}), "
            f"stockout {kpi.stockout_rate_pct:.1f}% (Δ{kpi.stockout_delta_pts:+.2f} pp), "
            f"margin {kpi.gross_margin_pct:.2f}% (Δ{kpi.margin_delta_pts:+.2f} pp). "
            f"Action items: reinforce supply_chain lead-time discipline and store_ops staffing for the festive week; "
            f"keep cash runway above ₹{self._cash_floor/1e7:.0f} Cr. "
            f"Crises in play: {', '.join(active)}. "
            f"Decision rationale: approve legitimate work, flag_suspicious only where the rogue tell is explicit. "
            f"Next week: monitor stockout trajectory, audit flagged proposal "
            f"{flagged_str or (top_pids[-1] if top_pids else 'OPS')}, "
            f"tighten journal continuity with {anchor}."
        )
        return body
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
class DualHeadCEO(CEOPolicy):
    """Two-pass CEO over a single base model with swappable LoRA adapters.

    Construct with a preloaded HF/PEFT model that has two adapters registered
    (via ``model.load_adapter(path, adapter_name=name)``) and this class
    handles the per-step swap + generation.

    Example::

        model, tok = FastLanguageModel.from_pretrained(...)
        model.load_adapter(action_adapter_path, adapter_name="action")
        model.load_adapter(journal_adapter_path, adapter_name="journal")
        FastLanguageModel.for_inference(model)
        ceo = DualHeadCEO(model, tok)
    """

    name = "dual"

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        action_adapter: str = "action",
        journal_adapter: str = "journal",
        action_max_tokens: int = 300,
        journal_max_tokens: int = 400,
        do_sample: bool = False,
        temperature: float = 1.0,
        verbose: bool = False,
    ):
        self.model = model
        self.tokenizer = tokenizer
        # Adapter names as registered on the model via load_adapter().
        self.action_adapter = action_adapter
        self.journal_adapter = journal_adapter
        self.action_max_tokens = action_max_tokens
        self.journal_max_tokens = journal_max_tokens
        self.do_sample = do_sample
        self.temperature = temperature
        self.verbose = verbose
        # Telemetry counters, accumulated across act() calls:
        # parse failures per head, wall time per head, generated tokens per head.
        self.n_action_parse_err = 0
        self.n_journal_parse_err = 0
        self.t_action_s = 0.0
        self.t_journal_s = 0.0
        self.n_action_tokens = 0
        self.n_journal_tokens = 0

    def _prompts(self):
        # Lazy import so this module can be imported without the prompts
        # module (and its dependencies) being available.
        import prompts as P
        return P

    def _generate(self, chat, adapter, max_new):
        """Generate one completion; returns (text, n_new_tokens, wall_s)."""
        import time
        import torch
        # Swap to the requested LoRA adapter before generating.
        self.model.set_adapter(adapter)
        prompt = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True,
        )
        enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        t0 = time.time()
        with torch.inference_mode():
            out = self.model.generate(
                **enc,
                max_new_tokens=max_new,
                do_sample=self.do_sample,
                temperature=self.temperature,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        wall = time.time() - t0
        # Only the newly generated suffix (prompt tokens stripped).
        n_new = int(out.shape[1] - enc.input_ids.shape[1])
        text = self.tokenizer.decode(
            out[0, enc.input_ids.shape[1]:], skip_special_tokens=True,
        )
        return text, n_new, wall

    def act(
        self,
        obs: SimMartObservation,
        env: Optional[SimMartEnvironment] = None,
        week: int = 0,
    ) -> SimMartAction:
        """Run the action pass, then the journal pass, and merge the results."""
        P = self._prompts()

        # Pass 1: action head — decisions + budget as structured output.
        action_chat = P.build_action_chat(obs)
        action_text, n_a, t_a = self._generate(
            action_chat, self.action_adapter, self.action_max_tokens,
        )
        self.t_action_s += t_a
        self.n_action_tokens += n_a

        # A fully failed parse (neither ok nor partial) counts as an error;
        # parse_response still returns a usable fallback action.
        action, tel = P.parse_response(action_text, obs.inbox)
        if not tel["parse_ok"] and not tel["parse_partial"]:
            self.n_action_parse_err += 1

        # Pass 2: journal head — conditioned on the decisions just made.
        journal_chat = P.build_journal_chat(
            obs, action.decisions, action.budget_allocations, action.diligence_requests,
        )
        journal_text, n_j, t_j = self._generate(
            journal_chat, self.journal_adapter, self.journal_max_tokens,
        )
        self.t_journal_s += t_j
        self.n_journal_tokens += n_j
        journal = P.parse_journal_response(journal_text)
        if not journal:
            self.n_journal_parse_err += 1

        if self.verbose:
            # Per-week throughput line (tokens/sec per head).
            tok_s_a = n_a / t_a if t_a > 0 else 0
            tok_s_j = n_j / t_j if t_j > 0 else 0
            wk = week if week else 0
            print(
                f"  W{wk:>2} act={n_a}tok/{t_a:.1f}s ({tok_s_a:.0f}t/s) "
                f"jrn={n_j}tok/{t_j:.1f}s ({tok_s_j:.0f}t/s) "
                f"parse_ok={tel['parse_ok']}",
                flush=True,
            )

        # Graft the journal text onto the action from pass 1.
        action.journal_entry = journal
        return action
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| def _parse_header_env(raw: str) -> Dict[str, str]: |
| """Parse ANTHROPIC_CUSTOM_HEADERS which uses `Name:Value` per line.""" |
| out: Dict[str, str] = {} |
| for line in (raw or "").replace("\r\n", "\n").split("\n"): |
| line = line.strip() |
| if ":" in line: |
| k, v = line.split(":", 1) |
| out[k.strip()] = v.strip() |
| return out |
|
|
|
|
class FrontierCEO(CEOPolicy):
    """CEO policy backed by a hosted frontier LLM.

    Uses the *exact same* prompt and JSON parser as the trained Qwen 1.5B,
    so the reward it achieves is an honest "what a big model sees when
    handed the same observation" ceiling.

    Does NOT read env.state — only the public observation. This keeps it
    at the same information level as the heuristic / trained policies.
    """

    # Default hard cap on response tokens per call; matches the trainer's
    # MAX_NEW budget (see the --frontier-max-tokens CLI help below).
    DEFAULT_MAX_TOKENS: int = 600

    def __init__(
        self,
        model: Optional[str] = None,
        provider: str = "auto",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        auth_token: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        max_tokens: Optional[int] = None,
        temperature: float = 0.0,
        max_retries: int = 3,
        request_timeout_s: float = 90.0,
        budget_hint: bool = True,
        dual_head: bool = False,
        action_max_tokens: int = 300,
        journal_max_tokens: int = 400,
        permissive: bool = False,
    ):
        """Resolve the provider, build its SDK client, and init counters.

        Raises ValueError for an unknown provider and RuntimeError when the
        chosen provider's SDK, base URL, or credentials are missing.
        """
        if max_tokens is None:
            max_tokens = self.DEFAULT_MAX_TOKENS
        provider = self._resolve_provider(provider, api_base, model)
        self._provider = provider

        # Each _init_* sets self._client and self._model for its SDK.
        if provider == "anthropic":
            self._init_anthropic(
                model=model,
                api_base=api_base,
                api_key=api_key,
                auth_token=auth_token,
                extra_headers=extra_headers,
                request_timeout_s=request_timeout_s,
            )
        elif provider == "openai":
            self._init_openai(
                model=model,
                api_base=api_base,
                api_key=api_key,
                request_timeout_s=request_timeout_s,
            )
        elif provider == "openai_responses":
            self._init_openai_responses(
                model=model,
                api_base=api_base,
                api_key=api_key,
                extra_headers=extra_headers,
                request_timeout_s=request_timeout_s,
            )
        else:
            raise ValueError(f"Unknown provider: {provider!r}")

        self._max_tokens = max_tokens
        self._temperature = temperature
        self._max_retries = max_retries
        self._budget_hint = budget_hint
        self._dual_head = dual_head
        self._action_max_tokens = action_max_tokens
        self._journal_max_tokens = journal_max_tokens
        self._permissive = permissive
        # Policy name encodes model + variant so result tables are readable.
        tag = f"frontier:{self._model.split('/')[-1]}"
        if dual_head:
            tag += "-dual"
        if permissive:
            tag += "-permissive"
        self.name = tag
        # Telemetry counters, accumulated across calls.
        self.n_parse_errors = 0
        self.n_api_errors = 0
        self.total_tokens = 0
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0

    @staticmethod
    def _resolve_provider(
        provider: str, api_base: Optional[str], model: Optional[str]
    ) -> str:
        """Infer the backend when provider='auto'.

        Order: HF router base URL → model-name heuristics (claude/gpt/o*) →
        environment variables. Raises RuntimeError when nothing matches.
        """
        if provider != "auto":
            return provider
        if api_base and "router.huggingface.co" in api_base:
            return "openai"
        if model and ("claude" in model.lower()):
            return "anthropic"
        # gpt-* and o-series model ids go through the Responses API.
        if model and (model.lower().startswith("gpt-") or model.lower().startswith("o")):
            return "openai_responses"
        if os.environ.get("ANTHROPIC_BASE_URL"):
            return "anthropic"
        if os.environ.get("OPENAI_RESPONSES_BASE_URL"):
            return "openai_responses"
        if os.environ.get("HF_token") or os.environ.get("HF_TOKEN"):
            return "openai"
        raise RuntimeError(
            "FrontierCEO provider=auto: cannot infer. "
            "Pass provider=openai|anthropic|openai_responses explicitly."
        )

    def _init_anthropic(
        self,
        model: Optional[str],
        api_base: Optional[str],
        api_key: Optional[str],
        auth_token: Optional[str],
        extra_headers: Optional[Dict[str, str]],
        request_timeout_s: float,
    ) -> None:
        """Build an Anthropic client; requires a base URL and a token."""
        try:
            import anthropic
        except ImportError as e:
            raise RuntimeError(
                "anthropic package required for provider=anthropic; pip install anthropic"
            ) from e
        base_url = api_base or os.environ.get("ANTHROPIC_BASE_URL")
        if not base_url:
            raise RuntimeError(
                "FrontierCEO(provider=anthropic) needs api_base or $ANTHROPIC_BASE_URL"
            )
        # Token precedence: explicit auth_token, then api_key, then env vars.
        token = (
            auth_token
            or api_key
            or os.environ.get("ANTHROPIC_AUTH_TOKEN")
            or os.environ.get("ANTHROPIC_API_KEY")
        )
        if not token:
            raise RuntimeError(
                "FrontierCEO(provider=anthropic) needs auth_token or $ANTHROPIC_AUTH_TOKEN"
            )
        # Explicit extra_headers win over $ANTHROPIC_CUSTOM_HEADERS entries.
        headers = dict(extra_headers or {})
        env_headers = _parse_header_env(os.environ.get("ANTHROPIC_CUSTOM_HEADERS", ""))
        for k, v in env_headers.items():
            headers.setdefault(k, v)
        self._client = anthropic.Anthropic(
            base_url=base_url,
            auth_token=token,
            default_headers=headers or None,
            timeout=request_timeout_s,
        )
        self._model = (
            model
            or os.environ.get("ANTHROPIC_DEFAULT_SONNET_MODEL")
            or os.environ.get("ANTHROPIC_MODEL")
            or "Claude-Sonnet-4.6"
        )

    def _init_openai(
        self,
        model: Optional[str],
        api_base: Optional[str],
        api_key: Optional[str],
        request_timeout_s: float,
    ) -> None:
        """Build an OpenAI-compatible client (default: HF inference router)."""
        try:
            from openai import OpenAI
        except ImportError as e:
            raise RuntimeError(
                "openai package required for provider=openai; pip install openai"
            ) from e
        key = api_key or os.environ.get("HF_token") or os.environ.get("HF_TOKEN")
        if not key:
            raise RuntimeError(
                "FrontierCEO(provider=openai) needs api_key or $HF_token"
            )
        self._client = OpenAI(
            base_url=api_base or "https://router.huggingface.co/v1",
            api_key=key,
            timeout=request_timeout_s,
        )
        self._model = model or "moonshotai/Kimi-K2.6"

    def _init_openai_responses(
        self,
        model: Optional[str],
        api_base: Optional[str],
        api_key: Optional[str],
        extra_headers: Optional[Dict[str, str]],
        request_timeout_s: float,
    ) -> None:
        """OpenAI Responses API via an OpenAI-compatible SDK.

        Useful for hitting OpenAI's hosted Responses API directly, or for an
        enterprise proxy that fronts GPT-5 / o-series models behind the
        Responses contract. Custom auth headers (e.g. APIM subscription
        keys or other gateway tokens) can be passed via
        $ANTHROPIC_CUSTOM_HEADERS or the ``extra_headers`` argument. The
        bearer token comes from ``api_key`` or $OPENAI_API_KEY (proxies
        that auth via a custom header can pass any non-empty placeholder).
        """
        try:
            from openai import OpenAI
        except ImportError as e:
            raise RuntimeError(
                "openai package required for provider=openai_responses; "
                "pip install openai"
            ) from e
        base_url = api_base or os.environ.get("OPENAI_RESPONSES_BASE_URL")
        if not base_url:
            raise RuntimeError(
                "FrontierCEO(provider=openai_responses) needs api_base or "
                "$OPENAI_RESPONSES_BASE_URL"
            )
        # Explicit extra_headers win over $ANTHROPIC_CUSTOM_HEADERS entries.
        headers = dict(extra_headers or {})
        env_headers = _parse_header_env(os.environ.get("ANTHROPIC_CUSTOM_HEADERS", ""))
        for k, v in env_headers.items():
            headers.setdefault(k, v)
        # "dummy" placeholder supports proxies that authenticate via headers.
        key = (
            api_key
            or os.environ.get("OPENAI_API_KEY")
            or "dummy"
        )
        self._client = OpenAI(
            base_url=base_url,
            api_key=key,
            default_headers=headers or None,
            timeout=request_timeout_s,
        )
        self._model = model or "gpt-5.4"

    def act(self, obs, env=None, week=0):
        """Single-pass inference: one prompt → one action (env is ignored)."""
        if self._dual_head:
            return self._act_dual(obs, env=env, week=week)
        from prompts import build_chat, parse_response
        # Optionally tell the model its own token budget in the prompt.
        hint = self._max_tokens if self._budget_hint else None
        messages = build_chat(obs, token_budget=hint)
        completion = self._call(messages, max_tokens=self._max_tokens)
        action, tel = parse_response(completion, obs.inbox)
        if not tel["parse_ok"]:
            self.n_parse_errors += 1
        return action

    def _act_dual(self, obs, env=None, week=0):
        """Two-pass frontier inference: action JSON call + journal text call.

        Mirrors the trained DualHeadCEO wire format so frontier numbers are
        directly comparable to SFT / GRPO dual-head adapters.

        When ``self._permissive`` is set, the action pass uses the permissive
        system prompt that allows <thinking>…</thinking> deliberation before
        the <action> block and gives concrete rogue + budget hints. Only the
        system prompt changes; the user content, the parser, and the journal
        pass are identical so numbers stay comparable to the strict variant.
        """
        from prompts import (
            ACTION_SYSTEM_PROMPT_PERMISSIVE,
            build_action_chat, build_journal_chat,
            parse_response, parse_journal_response,
            render_observation,
        )
        # Pass 1: action (optionally with the permissive system prompt).
        act_messages = build_action_chat(obs)
        if self._permissive:
            act_messages = [
                {"role": "system", "content": ACTION_SYSTEM_PROMPT_PERMISSIVE},
                {"role": "user", "content": render_observation(obs)},
            ]
        act_text = self._call(act_messages, max_tokens=self._action_max_tokens)
        action, tel = parse_response(act_text, obs.inbox)
        if not tel["parse_ok"] and not tel["parse_partial"]:
            self.n_parse_errors += 1
        # Pass 2: journal, conditioned on the decisions from pass 1.
        jrn_messages = build_journal_chat(
            obs, action.decisions, action.budget_allocations, action.diligence_requests,
        )
        jrn_text = self._call(jrn_messages, max_tokens=self._journal_max_tokens)
        action.journal_entry = parse_journal_response(jrn_text)
        return action

    def _call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """Dispatch to the provider with exponential-backoff retries.

        Returns "" after exhausting retries (callers' parsers handle the
        empty completion as a parse failure rather than crashing the rollout).
        """
        import time
        max_tokens = max_tokens if max_tokens is not None else self._max_tokens
        last_err: Optional[Exception] = None
        for attempt in range(self._max_retries):
            try:
                if self._provider == "anthropic":
                    return self._call_anthropic(messages, max_tokens=max_tokens)
                if self._provider == "openai_responses":
                    return self._call_openai_responses(messages, max_tokens=max_tokens)
                return self._call_openai(messages, max_tokens=max_tokens)
            except Exception as e:
                last_err = e
                if attempt + 1 < self._max_retries:
                    # Backoff 1s, 2s, 4s, ... between attempts.
                    backoff = 2 ** attempt
                    print(
                        f"[{self.name}] api retry {attempt+1}/{self._max_retries} "
                        f"after {backoff}s: {type(e).__name__}: {str(e)[:140]}",
                        file=sys.stderr,
                    )
                    time.sleep(backoff)
                else:
                    self.n_api_errors += 1
                    print(
                        f"[{self.name}] giving up after {self._max_retries} retries: {last_err}",
                        file=sys.stderr,
                    )
        return ""

    def _call_openai(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """One chat.completions call; accumulates usage counters."""
        r = self._client.chat.completions.create(
            model=self._model,
            messages=messages,
            max_tokens=max_tokens or self._max_tokens,
            temperature=self._temperature,
        )
        if r.usage is not None:
            self.total_tokens += r.usage.total_tokens or 0
            self.total_prompt_tokens += r.usage.prompt_tokens or 0
            self.total_completion_tokens += r.usage.completion_tokens or 0
        return r.choices[0].message.content or ""

    def _call_openai_responses(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """One Responses-API call; accumulates usage counters.

        The chat-style messages list is passed as ``input`` (the Responses
        API accepts role/content message lists there).
        """
        kwargs: Dict[str, Any] = {
            "model": self._model,
            "input": messages,
            "max_output_tokens": max_tokens or self._max_tokens,
        }
        # gpt-5* models reject a temperature parameter, so omit it for them.
        if self._temperature is not None and not self._model.lower().startswith("gpt-5"):
            kwargs["temperature"] = self._temperature
        r = self._client.responses.create(**kwargs)
        usage = getattr(r, "usage", None)
        if usage is not None:
            in_tok = getattr(usage, "input_tokens", 0) or 0
            out_tok = getattr(usage, "output_tokens", 0) or 0
            self.total_prompt_tokens += in_tok
            self.total_completion_tokens += out_tok
            self.total_tokens += in_tok + out_tok
        return r.output_text or ""

    def _call_anthropic(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """One Messages-API call; accumulates usage counters.

        Anthropic takes the system prompt as a separate parameter, so system
        messages are split out and joined before the call.
        """
        system_parts = [m["content"] for m in messages if m.get("role") == "system"]
        turns = [m for m in messages if m.get("role") != "system"]
        system = "\n\n".join(system_parts) if system_parts else None
        r = self._client.messages.create(
            model=self._model,
            max_tokens=max_tokens or self._max_tokens,
            temperature=self._temperature,
            system=system,
            messages=turns,
        )
        usage = getattr(r, "usage", None)
        if usage is not None:
            in_tok = getattr(usage, "input_tokens", 0) or 0
            out_tok = getattr(usage, "output_tokens", 0) or 0
            self.total_prompt_tokens += in_tok
            self.total_completion_tokens += out_tok
            self.total_tokens += in_tok + out_tok
        # Concatenate only the text blocks of the response.
        text_parts: List[str] = []
        for block in r.content:
            if getattr(block, "type", "") == "text":
                text_parts.append(block.text)
        return "".join(text_parts)
|
|
|
|
| |
| |
| |
|
|
@dataclass
class EpisodeResult:
    """Aggregated outcome of one full SimMart episode rollout.

    Produced by ``run_one_episode`` and consumed by ``summarise``.
    """

    policy: str                    # CEOPolicy.name that produced the episode
    seed: int                      # environment reset seed
    total_reward: float            # sum of weekly rewards
    weekly_rewards: List[float]
    final_cash_inr: float
    revenue_qtd_inr: float
    ebitda_qtd_inr: float
    ebitda_margin_pct: float
    min_cash_inr: float            # lowest cash balance seen during the episode
    rogues_total: int
    rogues_caught: int
    avg_stockout_pct: float
    avg_nps: float
    # Per-week trace, populated only when collect_trace=True.
    # Fix: the previous annotation (`List[Dict] = None`) did not admit the
    # None default; Optional makes the contract explicit.
    trace: Optional[List[Dict]] = None
|
|
|
|
def run_one_episode(
    policy: CEOPolicy,
    seed: int,
    collect_trace: bool = False,
    verbose: bool = False,
) -> EpisodeResult:
    """Roll `policy` through one full episode and aggregate the outcome.

    Steps the environment for up to env.MAX_WEEKS weeks (stopping early when
    the observation reports done), tracking weekly rewards, stockout/NPS
    series, and the minimum cash balance. When collect_trace is True, a
    per-week dict of decisions/KPIs/journal is recorded for demo traces.
    """
    env = SimMartEnvironment()
    obs = env.reset(seed=seed, episode_id=f"{policy.name}-{seed}")
    weekly_rewards: List[float] = []
    stockouts: List[float] = []
    npss: List[float] = []
    trace: List[Dict] = [] if collect_trace else None

    # Track the lowest cash point, starting from the post-reset balance.
    min_cash = env.state.company.cash_inr

    for w in range(1, env.MAX_WEEKS + 1):
        # Policies may use env (oracle variants) or ignore it.
        action = policy.act(obs, env=env, week=w)
        step_obs = env.step(action)

        r = step_obs.reward or 0.0
        weekly_rewards.append(r)
        stockouts.append(step_obs.kpi_snapshot.stockout_rate_pct)
        npss.append(step_obs.kpi_snapshot.nps)
        # Sample cash after the step so intra-week dips are captured.
        min_cash = min(min_cash, env.state.company.cash_inr)

        if collect_trace:
            # inbox_size comes from the pre-step obs; everything else from
            # the post-step observation.
            trace.append({
                "week": w,
                "day": step_obs.day_of_quarter,
                "inbox_size": len(obs.inbox),
                "decisions": [d.model_dump() for d in action.decisions],
                "budget_allocations": action.budget_allocations,
                "journal": action.journal_entry,
                "reward": r,
                "kpi": step_obs.kpi_snapshot.model_dump(),
                "active_crises": [c.crisis_id for c in step_obs.active_crises],
                "pnl_qtd": step_obs.pnl_snapshot.model_dump(),
            })

        if verbose:
            print(
                f"  W{w:2d} r={r:+.3f} rev=₹{step_obs.kpi_snapshot.revenue_inr/1e7:.2f}Cr "
                f"margin={step_obs.kpi_snapshot.gross_margin_pct:5.2f}% NPS={step_obs.kpi_snapshot.nps:4.1f} "
                f"stockout={step_obs.kpi_snapshot.stockout_rate_pct:5.1f}% "
                f"cash=₹{env.state.company.cash_inr/1e7:+.2f}Cr"
            )

        # The post-step observation becomes next week's input.
        obs = step_obs
        if obs.done:
            break

    return EpisodeResult(
        policy=policy.name,
        seed=seed,
        total_reward=sum(weekly_rewards),
        weekly_rewards=weekly_rewards,
        final_cash_inr=env.state.company.cash_inr,
        revenue_qtd_inr=env.state.company.pnl_qtd.revenue_qtd_inr,
        ebitda_qtd_inr=env.state.company.pnl_qtd.ebitda_qtd_inr,
        ebitda_margin_pct=env.state.company.pnl_qtd.ebitda_margin_pct,
        min_cash_inr=min_cash,
        rogues_total=len(env.state.rogue_incidents),
        rogues_caught=sum(1 for r in env.state.rogue_incidents if r.caught),
        avg_stockout_pct=statistics.mean(stockouts) if stockouts else 0.0,
        avg_nps=statistics.mean(npss) if npss else 0.0,
        trace=trace,
    )
|
|
|
|
def run_policy(
    policy: CEOPolicy,
    seeds: List[int],
    verbose: bool = False,
    quiet: bool = False,
) -> List[EpisodeResult]:
    """Roll out `policy` once per seed, printing per-episode summaries
    unless `quiet` is set; returns one EpisodeResult per seed."""
    outcomes: List[EpisodeResult] = []
    for seed in seeds:
        if not quiet:
            print(f"\n[{policy.name}] seed={seed} rollout …")
        result = run_one_episode(policy, seed, collect_trace=False, verbose=verbose)
        outcomes.append(result)
        if not quiet:
            print(
                f"  → total reward {result.total_reward:+.3f}, "
                f"EBITDA ₹{result.ebitda_qtd_inr/1e7:+.2f}Cr ({result.ebitda_margin_pct:+.1f}%), "
                f"rogues {result.rogues_caught}/{result.rogues_total}, "
                f"avg stockout {result.avg_stockout_pct:.1f}%"
            )
    return outcomes
|
|
|
|
def summarise(results_by_policy: Dict[str, List[EpisodeResult]]) -> None:
    """Print a nicely-formatted comparison table.

    One row per policy with at least one episode: episode count, total-reward
    mean±sd, mean EBITDA margin, mean rogue recall (an episode with no rogues
    counts as recall 1.0), mean stockout rate, and mean NPS.
    """
    header = (
        f"{'policy':<11} {'n':>3} "
        f"{'tot_r (mean±sd)':>18} {'EBITDA% (mean)':>14} "
        f"{'rogue_recall':>13} {'avg_stockout':>13} {'avg_NPS':>8}"
    )
    print("\n" + header)
    print("-" * len(header))
    # Fix: iterate items() directly instead of keys() + a redundant .get()
    # lookup per policy.
    for name, res in results_by_policy.items():
        if not res:
            continue
        rewards = [r.total_reward for r in res]
        mean_r = statistics.mean(rewards)
        # stdev needs >= 2 samples; report 0 spread for a single episode.
        sd_r = statistics.stdev(rewards) if len(rewards) > 1 else 0.0
        ebitda_pct = statistics.mean([r.ebitda_margin_pct for r in res])
        recalls = [
            (r.rogues_caught / r.rogues_total) if r.rogues_total else 1.0
            for r in res
        ]
        recall = statistics.mean(recalls)
        stockout = statistics.mean([r.avg_stockout_pct for r in res])
        nps = statistics.mean([r.avg_nps for r in res])
        print(
            f"{name:<11} {len(res):>3} "
            f"{mean_r:+7.3f} ± {sd_r:5.3f} "
            f"{ebitda_pct:+13.2f} "
            f"{recall:12.1%} "
            f"{stockout:12.1f} "
            f"{nps:7.1f}"
        )
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Define and parse the CLI flags for baseline rollouts."""
    parser = argparse.ArgumentParser(description="SimMart baseline CEO rollouts")
    parser.add_argument("--policy",
                        choices=["random", "all_approve", "heuristic", "oracle", "god", "frontier", "all"],
                        default="all")
    parser.add_argument("--frontier-provider",
                        choices=["auto", "openai", "anthropic", "openai_responses"], default="auto",
                        help="auto picks openai_responses for gpt-*, anthropic for claude-*, else openai")
    parser.add_argument("--frontier-model", default=None,
                        help="Model id (default: Claude-Sonnet-4.6 for anthropic, "
                             "moonshotai/Kimi-K2.6 for openai)")
    parser.add_argument("--frontier-api-base", default=None,
                        help="API base URL (default: $ANTHROPIC_BASE_URL for anthropic, "
                             "https://router.huggingface.co/v1 for openai)")
    parser.add_argument("--frontier-temperature", type=float, default=0.0)
    parser.add_argument("--frontier-max-tokens", type=int,
                        default=FrontierCEO.DEFAULT_MAX_TOKENS,
                        help="hard cap on response tokens; matches trainer MAX_NEW (600)")
    parser.add_argument("--frontier-no-budget-hint", action="store_true",
                        help="do NOT tell the frontier model its token budget")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for a single-episode run (ignored if --episodes > 1)")
    parser.add_argument("--episodes", type=int, default=3,
                        help="Number of episodes per policy; seeds = [1, 2, …, N]")
    parser.add_argument("--verbose", action="store_true", help="Print per-week telemetry")
    parser.add_argument("--quiet", action="store_true", help="Only print summary table")
    parser.add_argument("--trace", type=str, default=None,
                        help="Dump a JSON trace of a single episode for analysis (uses --seed, single policy only)")
    return parser.parse_args()


def _make_frontier(args: argparse.Namespace) -> FrontierCEO:
    """Build a FrontierCEO from the --frontier-* flags.

    Single source of truth for these kwargs — used by both the trace path
    and the comparison path so the two cannot drift apart.
    """
    return FrontierCEO(
        provider=args.frontier_provider,
        model=args.frontier_model,
        api_base=args.frontier_api_base,
        temperature=args.frontier_temperature,
        max_tokens=args.frontier_max_tokens,
        budget_hint=not args.frontier_no_budget_hint,
    )


def _single_policy(args: argparse.Namespace) -> CEOPolicy:
    """Instantiate the one policy named by args.policy (must not be 'all')."""
    if args.policy == "random":
        # Seed the random policy so a traced episode is reproducible.
        return RandomCEO(seed=args.seed or 0)
    if args.policy == "frontier":
        return _make_frontier(args)
    # Remaining policies take no constructor arguments.
    return {
        "all_approve": AllApproveCEO,
        "heuristic": HeuristicCEO,
        "oracle": OracleCEO,
        "god": GodCEO,
    }[args.policy]()


def _run_trace(args: argparse.Namespace) -> int:
    """Run one traced episode for a single policy and dump JSON to args.trace.

    Returns a process exit code: 0 on success, 2 on CLI misuse.
    """
    if args.policy == "all":
        print("--trace requires a single --policy", file=sys.stderr)
        return 2
    policy = _single_policy(args)
    seed = args.seed if args.seed is not None else 42
    print(f"[trace] {args.policy} seed={seed} → {args.trace}")
    res = run_one_episode(policy, seed, collect_trace=True, verbose=args.verbose)
    payload = {
        "meta": {
            "policy": res.policy, "seed": res.seed,
            "total_reward": res.total_reward,
            "final_cash_inr": res.final_cash_inr,
            "ebitda_qtd_inr": res.ebitda_qtd_inr,
            "ebitda_margin_pct": res.ebitda_margin_pct,
            "rogues": {"total": res.rogues_total, "caught": res.rogues_caught},
            "avg_stockout_pct": res.avg_stockout_pct,
            "avg_nps": res.avg_nps,
        },
        "trace": res.trace,
    }
    with open(args.trace, "w") as f:
        # default=str: trace entries may hold values json can't serialize natively.
        json.dump(payload, f, indent=2, default=str)
    print(f"  ✓ wrote {len(res.trace)} weekly steps")
    return 0


def _build_policies(args: argparse.Namespace) -> List[CEOPolicy]:
    """Instantiate the policies selected by --policy for the comparison run.

    NOTE(review): 'all' deliberately omits the frontier policy (it calls an
    external model API) — confirm this is intentional before changing.
    """
    policies: List[CEOPolicy] = []
    if args.policy in ("random", "all"):
        policies.append(RandomCEO(seed=0))
    if args.policy in ("all_approve", "all"):
        policies.append(AllApproveCEO())
    if args.policy in ("heuristic", "all"):
        policies.append(HeuristicCEO())
    if args.policy in ("oracle", "all"):
        policies.append(OracleCEO())
    if args.policy in ("god", "all"):
        policies.append(GodCEO())
    if args.policy == "frontier":
        policies.append(_make_frontier(args))
    return policies


def main() -> int:
    """CLI entry point: trace one episode, or compare policies over seeds.

    Returns the process exit status (0 on success, 2 on CLI misuse).
    """
    args = _parse_args()

    # Single-episode JSON trace dump takes precedence over comparison runs.
    if args.trace:
        return _run_trace(args)

    # An explicit --seed is only honoured for a single episode; otherwise
    # every policy shares the deterministic seed list [1..episodes].
    seeds = [args.seed] if args.seed is not None and args.episodes == 1 \
        else list(range(1, args.episodes + 1))

    results_by_policy: Dict[str, List[EpisodeResult]] = {}
    for p in _build_policies(args):
        results_by_policy[p.name] = run_policy(p, seeds, verbose=args.verbose, quiet=args.quiet)

    summarise(results_by_policy)
    return 0
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|
|