# SimMart / inference.py
# (commit 745336f — inference: rename internal `amd` provider to generic `openai_responses`)
"""Baseline CEO policies for SimMart — random, heuristic, oracle.
Used for:
• Smoke-testing the environment end-to-end before GRPO training
• Establishing a reward floor (random), ceiling (oracle), and a sensible
non-learned target (heuristic) that the trained Qwen CEO must beat
• Generating the "before vs after training" contrast trace for the demo
Usage:
python inference.py --policy heuristic --seed 42 --episodes 3 --verbose
python inference.py --policy all --episodes 10 --quiet
python inference.py --trace hero_episode.json --policy oracle --seed 7
"""
from __future__ import annotations
import argparse
import json
import os
import random
import statistics
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
# Make repo importable no matter where we run from
HERE = os.path.dirname(os.path.abspath(__file__))
if HERE not in sys.path:
sys.path.insert(0, HERE)
from models import (
CrisisEvent,
Proposal,
ProposalDecision,
SimMartAction,
SimMartObservation,
)
from server.environment import SimMartEnvironment
# ---------------------------------------------------------------------------
# Policy base
# ---------------------------------------------------------------------------
class CEOPolicy:
    """Abstract weekly decision policy for the SimMart CEO.

    Subclasses implement :meth:`act`, mapping one weekly observation to one
    :class:`SimMartAction`. An implementation may consult ``env.state``
    (oracle-style cheating) or rely on the public observation alone
    (random / heuristic).
    """

    # Human-readable policy identifier used in logs and result tables.
    name: str = "base"

    def act(
        self,
        obs: SimMartObservation,
        env: Optional[SimMartEnvironment] = None,
        week: int = 0,
    ) -> SimMartAction:
        """Return this week's action. Subclasses must override."""
        raise NotImplementedError
# ---------------------------------------------------------------------------
# Random CEO
# ---------------------------------------------------------------------------
class RandomCEO(CEOPolicy):
    """CEO that issues weighted-random verdicts — the reward floor."""

    name = "random"

    def __init__(self, seed: int = 0):
        # Private RNG so episodes are reproducible per seed.
        self._rng = random.Random(seed)

    def act(self, obs, env=None, week=0):
        """Sample one weighted verdict per inbox proposal."""
        options = ["approve", "reject", "flag_suspicious"]
        probs = [0.65, 0.25, 0.10]
        decisions = []
        for proposal in obs.inbox:
            pick = self._rng.choices(options, weights=probs)[0]
            decisions.append(
                ProposalDecision(proposal_id=proposal.proposal_id, verdict=pick)
            )
        return SimMartAction(
            decisions=decisions,
            budget_allocations={
                "supply_chain": 5e6,
                "store_ops": 5e5,
                "finance": 1e6,
                "growth": 1e6,
            },
            journal_entry=f"Random CEO, week {week}. No pattern.",
        )
# ---------------------------------------------------------------------------
# All-Approve CEO — baseline that approves everything (the GRPO step-80 hack).
# Used to measure the EV gap between blind-approve and rule-based policies.
# ---------------------------------------------------------------------------
class AllApproveCEO(CEOPolicy):
    """Degenerate baseline that approves every proposal unscreened.

    Mirrors the GRPO step-80 reward hack; its EV gap versus the rule-based
    policies measures what blind approval costs.
    """

    name = "all_approve"

    def act(self, obs, env=None, week=0):
        """Approve the whole inbox; allocate a fixed aggressive budget."""
        decisions = []
        for proposal in obs.inbox:
            decisions.append(
                ProposalDecision(proposal_id=proposal.proposal_id, verdict="approve")
            )
        return SimMartAction(
            decisions=decisions,
            budget_allocations={
                "supply_chain": 1e7,
                "store_ops": 2e6,
                "finance": 1e6,
                "growth": 2e6,
            },
            journal_entry=f"All-approve CEO, week {week}.",
        )
# ---------------------------------------------------------------------------
# Heuristic CEO — hand-coded but sensible; reads only the public obs
# ---------------------------------------------------------------------------
class HeuristicCEO(CEOPolicy):
    """Hand-coded, rule-based CEO that reads only the public observation."""

    name = "heuristic"

    # Per-action spend ceilings beyond which the heuristic rejects (4dept-reject
    # env: opex-bearing actions hit EBITDA directly, so high-spend low-ROI
    # actions are net-negative-EV and must be rejected).
    SPEND_CEILINGS_INR: Dict[str, float] = {
        "campaign.launch": 15_00_000,   # spend > ₹15L → revenue uplift can't justify
        "capex.approve": 10_00_000,     # capex has no in-env revenue model → reject most
        "return.approve": 50_000,       # batch refunds > ₹50k usually wasteful
        "brand.ambassador": 5_00_000,   # NPS uplift +1.5 caps benefit
        "loyalty.update": 5_00_000,     # NPS uplift +0.8 caps benefit
        "hours.extend": 50_000,         # tiny revenue mult vs payroll
    }

    def __init__(self, cash_floor_inr: float = 3e7):
        # Below this cash level the policy goes into low-spend mode.
        self._cash_floor = cash_floor_inr

    @staticmethod
    def _crisis_flags(active: List[CrisisEvent]) -> Dict[str, bool]:
        """Map each active crisis id to True for quick membership checks."""
        return {c.crisis_id: True for c in active}

    def act(self, obs, env=None, week=0):
        """Decide every inbox proposal, allocate budget, write the journal."""
        tight_cash = obs.kpi_snapshot.cash_inr < self._cash_floor
        crisis_active = self._crisis_flags(obs.active_crises)
        decisions = [
            ProposalDecision(
                proposal_id=p.proposal_id,
                verdict=self._decide(p, tight_cash, crisis_active),
            )
            for p in obs.inbox
        ]
        # Budget allocation: lean into supply chain when not cash-tight.
        if tight_cash:
            budget = {"supply_chain": 5e6, "store_ops": 5e5, "finance": 5e5, "growth": 5e5}
        else:
            budget = {"supply_chain": 1e7, "store_ops": 2e6, "finance": 1e6, "growth": 2e6}
        return SimMartAction(
            decisions=decisions,
            budget_allocations=budget,
            journal_entry=self._write_journal(obs, decisions, tight_cash, week),
        )

    @classmethod
    def _decide(
        cls,
        p: Proposal,
        tight_cash: bool,
        crisis_active: Dict[str, bool],
    ) -> str:
        """Rule-based verdict for a single proposal.

        The only rogue tell readable from the observation alone is the R2
        ghost vendor (vendor ids prefixed 'V-SUSPICIOUS'); R1 qty inflation
        and R3 cost inflation need baselines that only the oracle sees via
        env.state. Opex-heavy actions are rejected once their spend exceeds
        the per-action ceiling, since the EBITDA hit outweighs the modeled
        benefit (4dept-reject env).
        """
        params = p.params or {}
        if str(params.get("vendor_id", "") or "").upper().startswith("V-SUSPICIOUS"):
            return "flag_suspicious"
        ceiling = cls.SPEND_CEILINGS_INR.get(p.action)
        if ceiling is not None:
            # The spend lives under a different param key per action type;
            # take the first truthy one, with a payroll estimate fallback
            # for hours.extend.
            raw = 0.0
            for key in ("spend_inr", "amount_inr", "refund_inr", "perk_cost_inr", "cost_inr"):
                if params.get(key):
                    raw = params[key]
                    break
            else:
                if p.action == "hours.extend":
                    raw = params.get("hours", 0) * 60 * 100
            if float(raw) > ceiling:
                return "reject"
        if p.urgency == "high":
            return "approve"
        if p.urgency == "low" and tight_cash:
            return "reject"
        return "approve"

    @staticmethod
    def _write_journal(obs, decisions, tight_cash, week):
        """Compose the weekly retrospective journal line."""
        verdicts = [d.verdict for d in decisions]
        n_approve = verdicts.count("approve")
        n_reject = verdicts.count("reject")
        n_flag = verdicts.count("flag_suspicious")
        kpi = obs.kpi_snapshot
        active = [f"{c.crisis_id}:{c.name}" for c in obs.active_crises]
        return (
            f"Week {week}: {n_approve} approved, {n_reject} rejected, {n_flag} flagged. "
            f"Cash ₹{kpi.cash_inr/1e7:+.2f} Cr (cash-tight={tight_cash}), "
            f"NPS {kpi.nps:.0f}, stockout {kpi.stockout_rate_pct:.1f}%, "
            f"SLA {kpi.delivery_sla_hit_rate_pct:.0f}%. "
            f"Active crises: {', '.join(active) or 'none'}. "
            f"Next week: watch cash and supply-chain replenishment."
        )
# ---------------------------------------------------------------------------
# Oracle CEO — peeks at env.state.rogue_incidents to flag rogues perfectly
# ---------------------------------------------------------------------------
class OracleCEO(HeuristicCEO):
    """Heuristic CEO plus ground-truth rogue flagging via env.state.

    Two cheats over the plain heuristic:
    1. Flags every rogue-associated proposal this week (via env.state).
    2. Otherwise keeps the heuristic's verdicts, so approvals rise when
       cash is not tight (avoids false rejects).
    """

    name = "oracle"

    def act(self, obs, env=None, week=0):
        """Run the heuristic, then override verdicts on known rogues."""
        if env is None:
            # No ground truth available — behave exactly like the heuristic.
            return super().act(obs, env, week)
        # Collect this week's rogue-associated proposal ids.
        rogue_ids = set()
        for incident in env.state.rogue_incidents:
            if week in incident.active_weeks:
                rogue_ids.update(incident.associated_proposal_ids)
        action = super().act(obs, env, week)
        action.decisions = [
            ProposalDecision(
                proposal_id=d.proposal_id,
                verdict="flag_suspicious",
                flag_reason="Rogue indicator: telltale pattern in proposal metadata.",
            )
            if d.proposal_id in rogue_ids
            else d
            for d in action.decisions
        ]
        action.journal_entry = action.journal_entry + " (Oracle override: flagged known rogues)."
        return action
# ---------------------------------------------------------------------------
# God CEO — empirical ceiling. Reads ground truth, approves-all non-rogues,
# perfect-flags every rogue (zero false flags), writes a journal engineered
# for max journal_coherence_score, and allocates budget to mitigate stockouts.
#
# Not a learned baseline — used only to answer "what is the max achievable
# total_reward on this env?". See grader.py for the 7 reward components.
# ---------------------------------------------------------------------------
class GodCEO(CEOPolicy):
    """Empirical reward ceiling — not a learned baseline.

    Cheats by reading ground truth from ``env.state``: perfect-flags every
    rogue proposal (zero false flags), approves all other work modulo the
    heuristic spend ceilings, allocates budget to suppress stockouts, and
    writes a journal engineered for max journal_coherence_score. Used only
    to answer "what is the max achievable total_reward on this env?"; see
    grader.py for the 7 reward components.
    """

    name = "rule_oracle"

    def __init__(self, cash_floor_inr: float = 3e7):
        # Below this cash level, low-urgency spends are rejected.
        self._cash_floor = cash_floor_inr
        # Previous week's journal text; echoed as a continuity anchor.
        self._last_journal: str = ""

    @staticmethod
    def _active_rogue_proposal_ids(env: SimMartEnvironment, week: int) -> set:
        """Return proposal ids tied to rogue incidents active in `week`."""
        ids: set = set()
        for r in env.state.rogue_incidents:
            if week in r.active_weeks:
                ids.update(r.associated_proposal_ids)
        return ids

    def act(self, obs, env=None, week=0):
        """Decide the week with ground-truth rogue knowledge.

        Falls back to the plain heuristic when no env handle is supplied.
        """
        if env is None:
            return HeuristicCEO().act(obs, env, week)
        rogue_ids = self._active_rogue_proposal_ids(env, week)
        cash = obs.kpi_snapshot.cash_inr
        tight_cash = cash < self._cash_floor
        decisions: List[ProposalDecision] = []
        for p in obs.inbox:
            if p.proposal_id in rogue_ids:
                # Perfect rogue detection → rogue_catch = +1.0, zero false flags
                decisions.append(ProposalDecision(
                    proposal_id=p.proposal_id,
                    verdict="flag_suspicious",
                    flag_reason="Rogue indicator: telltale pattern in proposal metadata.",
                ))
                continue
            # Non-rogue: reject negative-EV opex spends; approve everything
            # else (PO restocks drive the kpi_delta component). Verdict set:
            # {approve, reject, flag_suspicious}.
            verdict = "approve"
            params = p.params or {}
            ceiling = HeuristicCEO.SPEND_CEILINGS_INR.get(p.action)
            if ceiling is not None:
                # Spend lives under a different param key per action type;
                # hours.extend estimates payroll cost from the hour count.
                spend = float(
                    params.get("spend_inr")
                    or params.get("amount_inr")
                    or params.get("refund_inr")
                    or params.get("perk_cost_inr")
                    or params.get("cost_inr")
                    or (params.get("hours", 0) * 60 * 100 if p.action == "hours.extend" else 0)
                    or 0.0
                )
                if spend > ceiling:
                    verdict = "reject"
            if verdict == "approve" and p.urgency == "low" and tight_cash:
                verdict = "reject"
            decisions.append(ProposalDecision(proposal_id=p.proposal_id, verdict=verdict))
        # Budget: lean hard into supply_chain (kills stockouts, helps KPI
        # delta). Cash-tight → throttle.
        if tight_cash:
            budget = {"supply_chain": 6e6, "store_ops": 1e6, "finance": 5e5, "growth": 5e5}
        else:
            budget = {"supply_chain": 1.2e7, "store_ops": 2.5e6, "finance": 1e6, "growth": 2e6}
        journal = self._write_max_coherence_journal(
            obs, decisions, rogue_ids, tight_cash, week
        )
        self._last_journal = journal
        return SimMartAction(
            decisions=decisions,
            budget_allocations=budget,
            journal_entry=journal,
        )

    def _write_max_coherence_journal(
        self,
        obs,
        decisions: List[ProposalDecision],
        rogue_ids: set,
        tight_cash: bool,
        week: int,
    ) -> str:
        """Free-text retrospective journal for the 2-dept mini env.

        Mentions both depts, keeps continuity with last week, and references
        proposal IDs so a future judge can trace the decision rationale.
        """
        approved_or_flagged = [
            d.proposal_id for d in decisions
            if d.verdict in ("approve", "flag_suspicious")
        ]
        top_pids = approved_or_flagged[:4]
        flagged = [d.proposal_id for d in decisions if d.verdict == "flag_suspicious"]
        flagged_str = flagged[0] if flagged else ""
        kpi = obs.kpi_snapshot
        active = [f"{c.crisis_id}:{c.name}" for c in obs.active_crises] or ["none"]
        # Continuity anchor: echo a phrase from last week if we have one
        anchor = "SimMart operating cadence"
        if self._last_journal:
            for tok in ("SimMart", "supply", "store", "margin", "stockout"):
                if tok.lower() in self._last_journal.lower():
                    anchor = tok
                    break
        body = (
            f"Week {week} CEO decision log for SimMart, continuing last week's focus on {anchor}. "
            f"Priority this week was the supply_chain pipeline and store_ops SLA. "
            f"Approvals ({len([d for d in decisions if d.verdict=='approve'])}): "
            f"cleared {', '.join(top_pids[:3]) or 'the routine POs'} to keep stock cover healthy "
            f"and drive next-week revenue. "
            f"Flagged ({len(flagged)}): {flagged_str or 'no rogue signal this week'} — "
            f"risk pattern matched our rogue playbook (inflated PO qty / ghost vendor / unit-cost kickback). "
            f"Rejections: only low-urgency items when cash was tight. "
            f"KPI pulse: cash ₹{kpi.cash_inr/1e7:+.2f} Cr (tight={tight_cash}), "
            # BUG FIX: the NPS delta previously read "NPS {nps}{delta})" with
            # no opening parenthesis, emitting unbalanced text like
            # "NPS 41+0.5)" in every journal.
            f"NPS {kpi.nps:.0f} ({kpi.nps_delta:+.1f}), "
            f"stockout {kpi.stockout_rate_pct:.1f}% (Δ{kpi.stockout_delta_pts:+.2f} pp), "
            f"margin {kpi.gross_margin_pct:.2f}% (Δ{kpi.margin_delta_pts:+.2f} pp). "
            f"Action items: reinforce supply_chain lead-time discipline and store_ops staffing for the festive week; "
            f"keep cash runway above ₹{self._cash_floor/1e7:.0f} Cr. "
            f"Crises in play: {', '.join(active)}. "
            f"Decision rationale: approve legitimate work, flag_suspicious only where the rogue tell is explicit. "
            f"Next week: monitor stockout trajectory, audit flagged proposal "
            f"{flagged_str or (top_pids[-1] if top_pids else 'OPS')}, "
            f"tighten journal continuity with {anchor}."
        )
        return body
# ---------------------------------------------------------------------------
# DualHeadCEO — two-pass inference on a Qwen LoRA setup with separate
# action + journal adapters. Architecture:
#
# base model (Qwen 2.5 1.5B, bf16)
# ├─ adapter "action" — GRPO-trained (or SFT init); 300 tok output
# └─ adapter "journal" — SFT-trained on GodCEO journals; 400 tok output
#
# Each weekly step fires two forward passes: action first, then journal (which
# sees the decisions the action head just produced, so it can reference their
# IDs for the coherence bonus). This breaks the "one bad journal kills all
# reward components" cascade from GRPO v5 and gives clean per-head credit
# assignment in the RL loop.
# ---------------------------------------------------------------------------
class DualHeadCEO(CEOPolicy):
    """Two-pass CEO over a single base model with swappable LoRA adapters.

    Construct with a preloaded HF/PEFT model that has two adapters registered
    (via ``model.load_adapter(path, adapter_name=name)``) and this class
    handles the per-step swap + generation.

    Example::

        model, tok = FastLanguageModel.from_pretrained(...)
        model.load_adapter(action_adapter_path, adapter_name="action")
        model.load_adapter(journal_adapter_path, adapter_name="journal")
        FastLanguageModel.for_inference(model)
        ceo = DualHeadCEO(model, tok)
    """

    name = "dual"

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        action_adapter: str = "action",
        journal_adapter: str = "journal",
        action_max_tokens: int = 300,
        journal_max_tokens: int = 400,
        do_sample: bool = False,
        temperature: float = 1.0,
        verbose: bool = False,
    ):
        # Model and tokenizer arrive preloaded; this class never loads weights.
        self.model = model
        self.tokenizer = tokenizer
        # Adapter names as registered on the model via load_adapter(...).
        self.action_adapter = action_adapter
        self.journal_adapter = journal_adapter
        # Per-pass generation budgets (new tokens).
        self.action_max_tokens = action_max_tokens
        self.journal_max_tokens = journal_max_tokens
        # Sampling knobs forwarded verbatim to model.generate().
        self.do_sample = do_sample
        self.temperature = temperature
        self.verbose = verbose
        # Telemetry accumulated across act() calls: parse failures, wall time
        # and generated-token counts per head.
        self.n_action_parse_err = 0
        self.n_journal_parse_err = 0
        self.t_action_s = 0.0
        self.t_journal_s = 0.0
        self.n_action_tokens = 0
        self.n_journal_tokens = 0

    # Lazy import of prompts to keep this module free of torch deps for the
    # baseline-only code path.
    def _prompts(self):
        import prompts as P
        return P

    def _generate(self, chat, adapter, max_new):
        """Generate one completion; returns (text, n_new_tokens, wall_s)."""
        import time
        import torch
        # Select which LoRA head answers this pass.
        self.model.set_adapter(adapter)
        prompt = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True,
        )
        enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        t0 = time.time()
        with torch.inference_mode():
            out = self.model.generate(
                **enc,
                max_new_tokens=max_new,
                do_sample=self.do_sample,
                temperature=self.temperature,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        wall = time.time() - t0
        # Everything past the prompt length is newly generated text.
        n_new = int(out.shape[1] - enc.input_ids.shape[1])
        text = self.tokenizer.decode(
            out[0, enc.input_ids.shape[1]:], skip_special_tokens=True,
        )
        return text, n_new, wall

    def act(
        self,
        obs: SimMartObservation,
        env: Optional[SimMartEnvironment] = None,
        week: int = 0,
    ) -> SimMartAction:
        """Run the action pass, then the journal pass, and merge the result.

        The journal pass is built from the decisions the action pass just
        produced, so it can reference their IDs.
        """
        P = self._prompts()
        # --- Pass 1: action head -> decisions + budget ------------------------
        action_chat = P.build_action_chat(obs)
        action_text, n_a, t_a = self._generate(
            action_chat, self.action_adapter, self.action_max_tokens,
        )
        self.t_action_s += t_a
        self.n_action_tokens += n_a
        # Use the robust parser (request_info fallback). We explicitly ignore
        # any journal_entry the action head might have leaked into its JSON —
        # the journal pass owns that field.
        action, tel = P.parse_response(action_text, obs.inbox)
        if not tel["parse_ok"] and not tel["parse_partial"]:
            self.n_action_parse_err += 1
        # --- Pass 2: journal head -> free-form retrospective ------------------
        journal_chat = P.build_journal_chat(
            obs, action.decisions, action.budget_allocations, action.diligence_requests,
        )
        journal_text, n_j, t_j = self._generate(
            journal_chat, self.journal_adapter, self.journal_max_tokens,
        )
        self.t_journal_s += t_j
        self.n_journal_tokens += n_j
        journal = P.parse_journal_response(journal_text)
        if not journal:
            self.n_journal_parse_err += 1
        if self.verbose:
            # Per-week throughput line: token counts and tok/s for each pass.
            tok_s_a = n_a / t_a if t_a > 0 else 0
            tok_s_j = n_j / t_j if t_j > 0 else 0
            wk = week if week else 0
            print(
                f" W{wk:>2} act={n_a}tok/{t_a:.1f}s ({tok_s_a:.0f}t/s) "
                f"jrn={n_j}tok/{t_j:.1f}s ({tok_s_j:.0f}t/s) "
                f"parse_ok={tel['parse_ok']}",
                flush=True,
            )
        action.journal_entry = journal
        return action
# ---------------------------------------------------------------------------
# Frontier CEO — backed by a hosted frontier LLM
# provider="openai" → any OpenAI-compatible Chat Completions endpoint
# (HF router, Groq, Together, vLLM-served, …)
# provider="anthropic" → Anthropic Messages API (direct, Bedrock-via-proxy,
# corporate APIM proxy, etc.)
# provider="openai_responses" → OpenAI Responses API behind an OpenAI-compatible
# SDK (e.g. an enterprise APIM proxy that fronts
# GPT-5 / o-series via the Responses contract)
# ---------------------------------------------------------------------------
def _parse_header_env(raw: str) -> Dict[str, str]:
"""Parse ANTHROPIC_CUSTOM_HEADERS which uses `Name:Value` per line."""
out: Dict[str, str] = {}
for line in (raw or "").replace("\r\n", "\n").split("\n"):
line = line.strip()
if ":" in line:
k, v = line.split(":", 1)
out[k.strip()] = v.strip()
return out
class FrontierCEO(CEOPolicy):
    """CEO policy backed by a hosted frontier LLM.

    Uses the *exact same* prompt and JSON parser as the trained Qwen 1.5B,
    so the reward it achieves is an honest "what a big model sees when
    handed the same observation" ceiling.

    Does NOT read env.state — only the public observation. This keeps it
    at the same information level as the heuristic / trained policies.
    """

    # Match the trainer's per-step generation budget so every policy — trained
    # or frontier — gets the same hard cap on response length. See
    # scripts/launch_grpo_ddp.sh (MAX_NEW) and the env's truncation behaviour.
    DEFAULT_MAX_TOKENS: int = 600

    def __init__(
        self,
        model: Optional[str] = None,
        provider: str = "auto",  # "auto" | "openai" | "anthropic" | "openai_responses"
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        auth_token: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        max_tokens: Optional[int] = None,
        temperature: float = 0.0,
        max_retries: int = 3,
        request_timeout_s: float = 90.0,
        budget_hint: bool = True,
        dual_head: bool = False,
        action_max_tokens: int = 300,
        journal_max_tokens: int = 400,
        permissive: bool = False,
    ):
        """Build the API client for the resolved provider and store knobs.

        Raises ValueError for an unknown provider, and RuntimeError from the
        per-provider init when the SDK, base URL, or credentials are missing.
        `dual_head` switches act() to the two-pass action+journal flow;
        `permissive` swaps in the permissive action system prompt there.
        """
        if max_tokens is None:
            max_tokens = self.DEFAULT_MAX_TOKENS
        provider = self._resolve_provider(provider, api_base, model)
        self._provider = provider
        if provider == "anthropic":
            self._init_anthropic(
                model=model,
                api_base=api_base,
                api_key=api_key,
                auth_token=auth_token,
                extra_headers=extra_headers,
                request_timeout_s=request_timeout_s,
            )
        elif provider == "openai":
            self._init_openai(
                model=model,
                api_base=api_base,
                api_key=api_key,
                request_timeout_s=request_timeout_s,
            )
        elif provider == "openai_responses":
            self._init_openai_responses(
                model=model,
                api_base=api_base,
                api_key=api_key,
                extra_headers=extra_headers,
                request_timeout_s=request_timeout_s,
            )
        else:
            raise ValueError(f"Unknown provider: {provider!r}")
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._max_retries = max_retries
        self._budget_hint = budget_hint
        self._dual_head = dual_head
        self._action_max_tokens = action_max_tokens
        self._journal_max_tokens = journal_max_tokens
        self._permissive = permissive
        # Policy name encodes model + mode so result tables distinguish runs.
        tag = f"frontier:{self._model.split('/')[-1]}"
        if dual_head:
            tag += "-dual"
        if permissive:
            tag += "-permissive"
        self.name = tag
        # Telemetry accumulated across calls.
        self.n_parse_errors = 0
        self.n_api_errors = 0
        self.total_tokens = 0
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0

    @staticmethod
    def _resolve_provider(
        provider: str, api_base: Optional[str], model: Optional[str]
    ) -> str:
        """Resolve provider="auto" from the api_base, model name, or env vars.

        Explicit providers pass through unchanged; auto-resolution order is:
        HF router base → openai; claude model → anthropic; gpt-/o model →
        openai_responses; then env vars; else raise.
        """
        if provider != "auto":
            return provider
        if api_base and "router.huggingface.co" in api_base:
            return "openai"
        if model and ("claude" in model.lower()):
            return "anthropic"
        # GPT-ish model name → OpenAI Responses API endpoint
        # NOTE(review): startswith("o") also matches any model whose name
        # begins with "o" (not just o-series) — confirm this is intended.
        if model and (model.lower().startswith("gpt-") or model.lower().startswith("o")):
            return "openai_responses"
        if os.environ.get("ANTHROPIC_BASE_URL"):
            return "anthropic"
        if os.environ.get("OPENAI_RESPONSES_BASE_URL"):
            return "openai_responses"
        if os.environ.get("HF_token") or os.environ.get("HF_TOKEN"):
            return "openai"
        raise RuntimeError(
            "FrontierCEO provider=auto: cannot infer. "
            "Pass provider=openai|anthropic|openai_responses explicitly."
        )

    def _init_anthropic(
        self,
        model: Optional[str],
        api_base: Optional[str],
        api_key: Optional[str],
        auth_token: Optional[str],
        extra_headers: Optional[Dict[str, str]],
        request_timeout_s: float,
    ) -> None:
        """Create the Anthropic client; requires a base URL and a token.

        Raises RuntimeError when the SDK, base URL, or token is missing.
        """
        try:
            import anthropic
        except ImportError as e:
            raise RuntimeError(
                "anthropic package required for provider=anthropic; pip install anthropic"
            ) from e
        base_url = api_base or os.environ.get("ANTHROPIC_BASE_URL")
        if not base_url:
            raise RuntimeError(
                "FrontierCEO(provider=anthropic) needs api_base or $ANTHROPIC_BASE_URL"
            )
        # Token precedence: explicit auth_token, then api_key, then env vars.
        token = (
            auth_token
            or api_key
            or os.environ.get("ANTHROPIC_AUTH_TOKEN")
            or os.environ.get("ANTHROPIC_API_KEY")
        )
        if not token:
            raise RuntimeError(
                "FrontierCEO(provider=anthropic) needs auth_token or $ANTHROPIC_AUTH_TOKEN"
            )
        # Explicit extra_headers win over env-provided custom headers.
        headers = dict(extra_headers or {})
        env_headers = _parse_header_env(os.environ.get("ANTHROPIC_CUSTOM_HEADERS", ""))
        for k, v in env_headers.items():
            headers.setdefault(k, v)
        self._client = anthropic.Anthropic(
            base_url=base_url,
            auth_token=token,
            default_headers=headers or None,
            timeout=request_timeout_s,
        )
        self._model = (
            model
            or os.environ.get("ANTHROPIC_DEFAULT_SONNET_MODEL")
            or os.environ.get("ANTHROPIC_MODEL")
            or "Claude-Sonnet-4.6"
        )

    def _init_openai(
        self,
        model: Optional[str],
        api_base: Optional[str],
        api_key: Optional[str],
        request_timeout_s: float,
    ) -> None:
        """Create an OpenAI-compatible Chat Completions client.

        Defaults to the Hugging Face router base URL and an HF token from
        the environment; raises RuntimeError when SDK or key is missing.
        """
        try:
            from openai import OpenAI
        except ImportError as e:
            raise RuntimeError(
                "openai package required for provider=openai; pip install openai"
            ) from e
        key = api_key or os.environ.get("HF_token") or os.environ.get("HF_TOKEN")
        if not key:
            raise RuntimeError(
                "FrontierCEO(provider=openai) needs api_key or $HF_token"
            )
        self._client = OpenAI(
            base_url=api_base or "https://router.huggingface.co/v1",
            api_key=key,
            timeout=request_timeout_s,
        )
        self._model = model or "moonshotai/Kimi-K2.6"

    def _init_openai_responses(
        self,
        model: Optional[str],
        api_base: Optional[str],
        api_key: Optional[str],
        extra_headers: Optional[Dict[str, str]],
        request_timeout_s: float,
    ) -> None:
        """OpenAI Responses API via an OpenAI-compatible SDK.

        Useful for hitting OpenAI's hosted Responses API directly, or for an
        enterprise proxy that fronts GPT-5 / o-series models behind the
        Responses contract. Custom auth headers (e.g. APIM subscription
        keys or other gateway tokens) can be passed via
        $ANTHROPIC_CUSTOM_HEADERS or the ``extra_headers`` argument. The
        bearer token comes from ``api_key`` or $OPENAI_API_KEY (proxies
        that auth via a custom header can pass any non-empty placeholder).
        """
        try:
            from openai import OpenAI
        except ImportError as e:
            raise RuntimeError(
                "openai package required for provider=openai_responses; "
                "pip install openai"
            ) from e
        base_url = api_base or os.environ.get("OPENAI_RESPONSES_BASE_URL")
        if not base_url:
            raise RuntimeError(
                "FrontierCEO(provider=openai_responses) needs api_base or "
                "$OPENAI_RESPONSES_BASE_URL"
            )
        # Explicit extra_headers win over env-provided custom headers.
        headers = dict(extra_headers or {})
        env_headers = _parse_header_env(os.environ.get("ANTHROPIC_CUSTOM_HEADERS", ""))
        for k, v in env_headers.items():
            headers.setdefault(k, v)
        # SDK requires a non-empty key even when the proxy ignores it.
        key = (
            api_key
            or os.environ.get("OPENAI_API_KEY")
            or "dummy"
        )
        self._client = OpenAI(
            base_url=base_url,
            api_key=key,
            default_headers=headers or None,
            timeout=request_timeout_s,
        )
        self._model = model or "gpt-5.4"

    def act(self, obs, env=None, week=0):
        """Single-pass inference: one prompt → decisions + budget + journal.

        Delegates to _act_dual when configured for two-pass mode.
        """
        if self._dual_head:
            return self._act_dual(obs, env=env, week=week)
        from prompts import build_chat, parse_response
        hint = self._max_tokens if self._budget_hint else None
        messages = build_chat(obs, token_budget=hint)
        completion = self._call(messages, max_tokens=self._max_tokens)
        action, tel = parse_response(completion, obs.inbox)
        if not tel["parse_ok"]:
            self.n_parse_errors += 1
        return action

    def _act_dual(self, obs, env=None, week=0):
        """Two-pass frontier inference: action JSON call + journal text call.

        Mirrors the trained DualHeadCEO wire format so frontier numbers are
        directly comparable to SFT / GRPO dual-head adapters.

        When ``self._permissive`` is set, the action pass uses the permissive
        system prompt that allows <thinking>…</thinking> deliberation before
        the <action> block and gives concrete rogue + budget hints. Only the
        system prompt changes; the user content, the parser, and the journal
        pass are identical so numbers stay comparable to the strict variant.
        """
        from prompts import (
            ACTION_SYSTEM_PROMPT_PERMISSIVE,
            build_action_chat, build_journal_chat,
            parse_response, parse_journal_response,
            render_observation,
        )
        # --- Pass 1: action ---------------------------------------------------
        act_messages = build_action_chat(obs)
        if self._permissive:
            act_messages = [
                {"role": "system", "content": ACTION_SYSTEM_PROMPT_PERMISSIVE},
                {"role": "user", "content": render_observation(obs)},
            ]
        act_text = self._call(act_messages, max_tokens=self._action_max_tokens)
        action, tel = parse_response(act_text, obs.inbox)
        if not tel["parse_ok"] and not tel["parse_partial"]:
            self.n_parse_errors += 1
        # --- Pass 2: journal --------------------------------------------------
        jrn_messages = build_journal_chat(
            obs, action.decisions, action.budget_allocations, action.diligence_requests,
        )
        jrn_text = self._call(jrn_messages, max_tokens=self._journal_max_tokens)
        action.journal_entry = parse_journal_response(jrn_text)
        return action

    def _call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """Dispatch one request to the configured provider with retries.

        Exponential backoff (1s, 2s, 4s, …) between attempts; after the
        final failure, counts an api error and returns "" so the episode
        keeps running.
        """
        import time
        max_tokens = max_tokens if max_tokens is not None else self._max_tokens
        last_err: Optional[Exception] = None
        for attempt in range(self._max_retries):
            try:
                if self._provider == "anthropic":
                    return self._call_anthropic(messages, max_tokens=max_tokens)
                if self._provider == "openai_responses":
                    return self._call_openai_responses(messages, max_tokens=max_tokens)
                return self._call_openai(messages, max_tokens=max_tokens)
            except Exception as e:
                last_err = e
                if attempt + 1 < self._max_retries:
                    backoff = 2 ** attempt
                    print(
                        f"[{self.name}] api retry {attempt+1}/{self._max_retries} "
                        f"after {backoff}s: {type(e).__name__}: {str(e)[:140]}",
                        file=sys.stderr,
                    )
                    time.sleep(backoff)
                else:
                    self.n_api_errors += 1
                    print(
                        f"[{self.name}] giving up after {self._max_retries} retries: {last_err}",
                        file=sys.stderr,
                    )
        return ""

    def _call_openai(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """One Chat Completions request; accumulates token usage counters."""
        r = self._client.chat.completions.create(
            model=self._model,
            messages=messages,
            max_tokens=max_tokens or self._max_tokens,
            temperature=self._temperature,
        )
        if r.usage is not None:
            self.total_tokens += r.usage.total_tokens or 0
            self.total_prompt_tokens += r.usage.prompt_tokens or 0
            self.total_completion_tokens += r.usage.completion_tokens or 0
        return r.choices[0].message.content or ""

    def _call_openai_responses(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """One Responses API request; accumulates token usage counters."""
        # OpenAI Responses API. System prompt can travel inside `input` — we
        # pass the full messages list unchanged and rely on `output_text` to
        # flatten the result back to a string.
        kwargs: Dict[str, Any] = {
            "model": self._model,
            "input": messages,
            "max_output_tokens": max_tokens or self._max_tokens,
        }
        # gpt-5.* reasoning models reject a non-default temperature
        if self._temperature is not None and not self._model.lower().startswith("gpt-5"):
            kwargs["temperature"] = self._temperature
        r = self._client.responses.create(**kwargs)
        usage = getattr(r, "usage", None)
        if usage is not None:
            in_tok = getattr(usage, "input_tokens", 0) or 0
            out_tok = getattr(usage, "output_tokens", 0) or 0
            self.total_prompt_tokens += in_tok
            self.total_completion_tokens += out_tok
            self.total_tokens += in_tok + out_tok
        return r.output_text or ""

    def _call_anthropic(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
        """One Anthropic Messages request; accumulates token usage counters."""
        # Anthropic wants system prompt separate from the user/assistant turns.
        system_parts = [m["content"] for m in messages if m.get("role") == "system"]
        turns = [m for m in messages if m.get("role") != "system"]
        system = "\n\n".join(system_parts) if system_parts else None
        r = self._client.messages.create(
            model=self._model,
            max_tokens=max_tokens or self._max_tokens,
            temperature=self._temperature,
            system=system,
            messages=turns,
        )
        usage = getattr(r, "usage", None)
        if usage is not None:
            in_tok = getattr(usage, "input_tokens", 0) or 0
            out_tok = getattr(usage, "output_tokens", 0) or 0
            self.total_prompt_tokens += in_tok
            self.total_completion_tokens += out_tok
            self.total_tokens += in_tok + out_tok
        # Concatenate only the text blocks of the response content.
        text_parts: List[str] = []
        for block in r.content:
            if getattr(block, "type", "") == "text":
                text_parts.append(block.text)
        return "".join(text_parts)
# ---------------------------------------------------------------------------
# Rollout runner
# ---------------------------------------------------------------------------
@dataclass
class EpisodeResult:
    """Summary of one full-quarter rollout for a single policy/seed."""

    policy: str                 # policy.name that produced this episode
    seed: int                   # env reset seed
    total_reward: float         # sum of weekly rewards
    weekly_rewards: List[float]
    final_cash_inr: float
    revenue_qtd_inr: float
    ebitda_qtd_inr: float
    ebitda_margin_pct: float
    min_cash_inr: float         # lowest cash seen at any point in the episode
    rogues_total: int
    rogues_caught: int
    avg_stockout_pct: float
    avg_nps: float
    # Per-week trace dicts, populated only when collect_trace=True.
    # Fixed annotation: was `List[Dict] = None`, which mislabels a field
    # whose default (and common value) is None.
    trace: Optional[List[Dict]] = None
def run_one_episode(
    policy: CEOPolicy,
    seed: int,
    collect_trace: bool = False,
    verbose: bool = False,
) -> EpisodeResult:
    """Roll one policy through a full quarter and summarise the outcome.

    Steps the environment week by week until done or MAX_WEEKS, tracking
    rewards, KPI series, and the minimum cash point; optionally records a
    per-week trace of decisions / budgets / journals.
    """
    env = SimMartEnvironment()
    obs = env.reset(seed=seed, episode_id=f"{policy.name}-{seed}")
    weekly_rewards: List[float] = []
    stockouts: List[float] = []
    npss: List[float] = []
    trace: List[Dict] = [] if collect_trace else None
    min_cash = env.state.company.cash_inr
    for week in range(1, env.MAX_WEEKS + 1):
        action = policy.act(obs, env=env, week=week)
        step_obs = env.step(action)
        reward = step_obs.reward or 0.0
        kpi = step_obs.kpi_snapshot
        weekly_rewards.append(reward)
        stockouts.append(kpi.stockout_rate_pct)
        npss.append(kpi.nps)
        min_cash = min(min_cash, env.state.company.cash_inr)
        if collect_trace:
            trace.append({
                "week": week,
                "day": step_obs.day_of_quarter,
                "inbox_size": len(obs.inbox),
                "decisions": [d.model_dump() for d in action.decisions],
                "budget_allocations": action.budget_allocations,
                "journal": action.journal_entry,
                "reward": reward,
                "kpi": kpi.model_dump(),
                "active_crises": [c.crisis_id for c in step_obs.active_crises],
                "pnl_qtd": step_obs.pnl_snapshot.model_dump(),
            })
        if verbose:
            print(
                f" W{week:2d} r={reward:+.3f} rev=₹{kpi.revenue_inr/1e7:.2f}Cr "
                f"margin={kpi.gross_margin_pct:5.2f}% NPS={kpi.nps:4.1f} "
                f"stockout={kpi.stockout_rate_pct:5.1f}% "
                f"cash=₹{env.state.company.cash_inr/1e7:+.2f}Cr"
            )
        obs = step_obs
        if obs.done:
            break
    return EpisodeResult(
        policy=policy.name,
        seed=seed,
        total_reward=sum(weekly_rewards),
        weekly_rewards=weekly_rewards,
        final_cash_inr=env.state.company.cash_inr,
        revenue_qtd_inr=env.state.company.pnl_qtd.revenue_qtd_inr,
        ebitda_qtd_inr=env.state.company.pnl_qtd.ebitda_qtd_inr,
        ebitda_margin_pct=env.state.company.pnl_qtd.ebitda_margin_pct,
        min_cash_inr=min_cash,
        rogues_total=len(env.state.rogue_incidents),
        rogues_caught=sum(1 for r in env.state.rogue_incidents if r.caught),
        avg_stockout_pct=statistics.mean(stockouts) if stockouts else 0.0,
        avg_nps=statistics.mean(npss) if npss else 0.0,
        trace=trace,
    )
def run_policy(
    policy: CEOPolicy,
    seeds: List[int],
    verbose: bool = False,
    quiet: bool = False,
) -> List[EpisodeResult]:
    """Roll out `policy` once per seed and collect the episode results.

    Unless `quiet` is set, prints a header line before each rollout and a
    one-line summary (total reward, EBITDA, rogue catches, avg stockout)
    after it. `verbose` is forwarded to the per-week telemetry printer.
    """
    episode_results: List[EpisodeResult] = []
    for s in seeds:
        if not quiet:
            print(f"\n[{policy.name}] seed={s} rollout …")
        outcome = run_one_episode(policy, s, collect_trace=False, verbose=verbose)
        episode_results.append(outcome)
        if quiet:
            continue
        summary = (
            f" → total reward {outcome.total_reward:+.3f}, "
            f"EBITDA ₹{outcome.ebitda_qtd_inr/1e7:+.2f}Cr ({outcome.ebitda_margin_pct:+.1f}%), "
            f"rogues {outcome.rogues_caught}/{outcome.rogues_total}, "
            f"avg stockout {outcome.avg_stockout_pct:.1f}%"
        )
        print(summary)
    return episode_results
def summarise(results_by_policy: Dict[str, List[EpisodeResult]]) -> None:
    """Print a formatted comparison table, one row per policy.

    Columns: episode count, mean±sd total reward, mean EBITDA margin,
    rogue recall (a policy that saw zero rogue incidents counts as 1.0),
    mean stockout rate and mean NPS. Policies with an empty result list
    are skipped. Insertion order of the dict is preserved.
    """
    header = (
        f"{'policy':<11} {'n':>3} "
        f"{'tot_r (mean±sd)':>18} {'EBITDA% (mean)':>14} "
        f"{'rogue_recall':>13} {'avg_stockout':>13} {'avg_NPS':>8}"
    )
    print("\n" + header)
    print("-" * len(header))
    # Iterate items() directly — the original keys()+get() did a redundant
    # second dict lookup per policy for the same pairs.
    for name, res in results_by_policy.items():
        if not res:
            continue
        rewards = [r.total_reward for r in res]
        mean_r = statistics.mean(rewards)
        # stdev needs >= 2 samples; report 0.0 spread for a single episode.
        sd_r = statistics.stdev(rewards) if len(rewards) > 1 else 0.0
        ebitda_pct = statistics.mean([r.ebitda_margin_pct for r in res])
        recalls = [
            (r.rogues_caught / r.rogues_total) if r.rogues_total else 1.0
            for r in res
        ]
        recall = statistics.mean(recalls)
        stockout = statistics.mean([r.avg_stockout_pct for r in res])
        nps = statistics.mean([r.avg_nps for r in res])
        print(
            f"{name:<11} {len(res):>3} "
            f"{mean_r:+7.3f} ± {sd_r:5.3f} "
            f"{ebitda_pct:+13.2f} "
            f"{recall:12.1%} "
            f"{stockout:12.1f} "
            f"{nps:7.1f}"
        )
def main() -> int:
    """CLI entry point for baseline CEO rollouts.

    Two modes:
      * ``--trace FILE`` — run ONE episode of a single policy and dump a
        JSON trace (weekly decisions, KPIs, crises, P&L) to FILE.
      * sweep (default) — roll out each selected policy over seeds
        ``[1..N]`` (or the single ``--seed`` when ``--episodes 1``) and
        print a comparison table.

    Returns a process exit code: 0 on success, 2 on CLI misuse.
    """
    parser = argparse.ArgumentParser(description="SimMart baseline CEO rollouts")
    parser.add_argument("--policy",
                        choices=["random", "all_approve", "heuristic", "oracle", "god", "frontier", "all"],
                        default="all")
    parser.add_argument("--frontier-provider",
                        choices=["auto", "openai", "anthropic", "openai_responses"], default="auto",
                        help="auto picks openai_responses for gpt-*, anthropic for claude-*, else openai")
    parser.add_argument("--frontier-model", default=None,
                        help="Model id (default: Claude-Sonnet-4.6 for anthropic, "
                             "moonshotai/Kimi-K2.6 for openai)")
    parser.add_argument("--frontier-api-base", default=None,
                        help="API base URL (default: $ANTHROPIC_BASE_URL for anthropic, "
                             "https://router.huggingface.co/v1 for openai)")
    parser.add_argument("--frontier-temperature", type=float, default=0.0)
    parser.add_argument("--frontier-max-tokens", type=int,
                        default=FrontierCEO.DEFAULT_MAX_TOKENS,
                        help="hard cap on response tokens; matches trainer MAX_NEW (600)")
    parser.add_argument("--frontier-no-budget-hint", action="store_true",
                        help="do NOT tell the frontier model its token budget")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for a single-episode run (ignored if --episodes > 1)")
    parser.add_argument("--episodes", type=int, default=3,
                        help="Number of episodes per policy; seeds = [1, 2, …, N]")
    parser.add_argument("--verbose", action="store_true", help="Print per-week telemetry")
    parser.add_argument("--quiet", action="store_true", help="Only print summary table")
    parser.add_argument("--trace", type=str, default=None,
                        help="Dump a JSON trace of a single episode for analysis (uses --seed, single policy only)")
    args = parser.parse_args()

    # --- Trace mode: single episode, JSON dump ---
    if args.trace:
        if args.policy == "all":
            print("--trace requires a single --policy", file=sys.stderr)
            return 2
        policy_cls = {
            "random": RandomCEO, "all_approve": AllApproveCEO,
            "heuristic": HeuristicCEO,
            "oracle": OracleCEO, "god": GodCEO, "frontier": FrontierCEO,
        }[args.policy]
        if args.policy == "random":
            policy = policy_cls(seed=args.seed or 0)
        elif args.policy == "frontier":
            policy = FrontierCEO(
                provider=args.frontier_provider,
                model=args.frontier_model,
                api_base=args.frontier_api_base,
                temperature=args.frontier_temperature,
                max_tokens=args.frontier_max_tokens,
                budget_hint=not args.frontier_no_budget_hint,
            )
        else:
            policy = policy_cls()
        seed = args.seed if args.seed is not None else 42
        # FIX: the original f-string ran the seed and the output path
        # together with no separator (e.g. "seed=42hero.json").
        print(f"[trace] {args.policy} seed={seed} → {args.trace}")
        res = run_one_episode(policy, seed, collect_trace=True, verbose=args.verbose)
        payload = {
            "meta": {
                "policy": res.policy, "seed": res.seed,
                "total_reward": res.total_reward,
                "final_cash_inr": res.final_cash_inr,
                "ebitda_qtd_inr": res.ebitda_qtd_inr,
                "ebitda_margin_pct": res.ebitda_margin_pct,
                "rogues": {"total": res.rogues_total, "caught": res.rogues_caught},
                "avg_stockout_pct": res.avg_stockout_pct,
                "avg_nps": res.avg_nps,
            },
            "trace": res.trace,
        }
        # Explicit encoding so the dump is platform-independent.
        with open(args.trace, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, default=str)
        print(f" ✓ wrote {len(res.trace)} weekly steps")
        return 0

    # --- Sweep mode ---
    # A single explicit seed is honoured only for a 1-episode run;
    # otherwise seeds are always [1..episodes] for comparability.
    seeds = [args.seed] if args.seed is not None and args.episodes == 1 \
        else list(range(1, args.episodes + 1))
    policies: List[CEOPolicy] = []
    if args.policy in ("random", "all"):
        policies.append(RandomCEO(seed=0))
    if args.policy in ("all_approve", "all"):
        policies.append(AllApproveCEO())
    if args.policy in ("heuristic", "all"):
        policies.append(HeuristicCEO())
    if args.policy in ("oracle", "all"):
        policies.append(OracleCEO())
    if args.policy in ("god", "all"):
        policies.append(GodCEO())
    # frontier runs only when explicitly requested — it is not part of "all".
    if args.policy == "frontier":
        policies.append(FrontierCEO(
            provider=args.frontier_provider,
            model=args.frontier_model,
            api_base=args.frontier_api_base,
            temperature=args.frontier_temperature,
            max_tokens=args.frontier_max_tokens,
            budget_hint=not args.frontier_no_budget_hint,
        ))
    results_by_policy: Dict[str, List[EpisodeResult]] = {}
    for p in policies:
        results_by_policy[p.name] = run_policy(p, seeds, verbose=args.verbose, quiet=args.quiet)
    summarise(results_by_policy)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    sys.exit(main())