SimMart / server /rogue.py
Viani's picture
HF Space: 4-dept SimMart env + 1.5B SFT+GRPO training (hackathon submission)
5c35138
"""SimMart rogue scenarios (R1–R12).
Source: RETAIL_GROUND_TRUTH.md Β§7.
Each rogue is a scripted *malicious* proposal-generator that overlays the
normal dept stream with a telltale version of one proposal. The CEO's job
is to catch it via a `flag_suspicious` verdict.
Tells live in `Proposal.params` (or in the `reasoning` string) as plain
observable fields — the CEO does not see the `RogueIncident.tell` cheat
sheet in the observation; only `RogueIncident.associated_proposal_ids` +
`caught` live in state.
Rogues are scheduled at episode reset for 1 week each (active_weeks = [week])
and fire by appending/replacing a proposal in the inbox for that week.
Public API:
ROGUE_SPECS : id → scenario metadata
schedule_rogues(rng, rogue_prob, dept_drifts, cities, weeks)
→ List[RogueIncident]
active_this_week(rogues, week)
→ List[RogueIncident]
inject_rogue_proposals(base_proposals, active_rogues, week, ledger, rng)
→ List[Proposal] (modified + associated_proposal_ids populated)
mark_caught(rogues, week, decisions, inbox)
→ updates .caught; returns per-week metrics dict
episode_accuracy(rogues) → {total_rogues, caught, recall}
"""
from __future__ import annotations
import random
from typing import Any, Dict, List, Optional, Tuple
try:
from ..models import (
CompanyLedger,
Proposal,
ProposalDecision,
RogueIncident,
)
from . import economics as E
from . import proposals as P
except (ImportError, ModuleNotFoundError):
from models import (
CompanyLedger,
Proposal,
ProposalDecision,
RogueIncident,
)
from server import economics as E
from server import proposals as P
# ---------------------------------------------------------------------------
# Rogue registry (metadata only; generator lookup is below)
# ---------------------------------------------------------------------------
# Rogue id -> scenario metadata. "tell" describes the observable giveaway in
# the rogue's proposal; it is a cheat sheet for scoring, not part of the obs.
ROGUE_SPECS: Dict[str, Dict[str, Any]] = {
    "R1": dict(
        dept="supply_chain",
        scenario="Inflated PO qty (vendor kickback)",
        tell="qty > 2.0x trailing-week baseline visible in inventory_baseline_qty",
    ),
    "R2": dict(
        dept="supply_chain",
        scenario="Damaged-route wastage with suspicious vendor",
        tell="vendor_id starts with 'V-SUSPICIOUS-'",
    ),
    "R3": dict(
        dept="supply_chain",
        scenario="Vendor kickback via inflated unit cost",
        tell="unit_cost > 1.15x sku catalogue cost",
    ),
}
# Reverse index: dept -> rogue ids, both in ROGUE_SPECS insertion order.
ROGUES_BY_DEPT: Dict[str, List[str]] = {
    dept: [rid for rid, spec in ROGUE_SPECS.items() if spec["dept"] == dept]
    for dept in dict.fromkeys(spec["dept"] for spec in ROGUE_SPECS.values())
}
# ---------------------------------------------------------------------------
# Scheduler
# ---------------------------------------------------------------------------
def schedule_rogues(
    rng: random.Random,
    rogue_prob: float,
    dept_drifts: Dict[str, float],
    cities: List[str],
    weeks: int = E.WEEKS_PER_QUARTER,
) -> List[RogueIncident]:
    """Schedule up to 2 distinct rogues for an episode (mini env).

    Logic:
      • With probability ``rogue_prob`` the episode gets rogues at all.
      • If so, 2 *distinct* rogue types from ROGUE_SPECS fire, each in its own
        *distinct* week drawn from [2, max(3, weeks - 1)].
      • Two rogues (vs one) widen the skill ceiling: every miss is both a lost
        catch reward AND a bad proposal that can bleed margin, so a competent
        CEO outscores a passive one by a clear margin.

    Args:
        rng: Episode RNG. Draws happen in a fixed order (gate, rogue ids,
            weeks), so a seeded rng gives a reproducible schedule.
        rogue_prob: Probability that the episode has any rogues at all.
        dept_drifts: Not used by this implementation.
        cities: Not used by this implementation.
        weeks: Number of weeks in the episode.

    Returns:
        One RogueIncident per scheduled rogue, each active for exactly one
        week, with empty ``associated_proposal_ids`` and ``caught=False``.
        Empty list when the episode draws no rogues.
    """
    # rng.random() lies in [0, 1), so "fire iff draw < rogue_prob" happens with
    # probability exactly rogue_prob and never when rogue_prob == 0.
    if rng.random() >= rogue_prob:
        return []
    rogue_ids = rng.sample(list(ROGUE_SPECS.keys()), k=min(2, len(ROGUE_SPECS)))
    # Weeks 2..weeks-1 inclusive; max(3, ...) guarantees at least two slots
    # ([2, 3]) even for very short episodes.
    week_pool = list(range(2, max(3, weeks - 1) + 1))
    chosen_weeks = rng.sample(week_pool, k=min(len(rogue_ids), len(week_pool)))
    incidents: List[RogueIncident] = []
    for rid, wk in zip(rogue_ids, chosen_weeks):
        spec = ROGUE_SPECS[rid]
        incidents.append(RogueIncident(
            rogue_id=rid,
            dept=spec["dept"],
            scenario=spec["scenario"],
            active_weeks=[wk],
            tell={"description": spec["tell"]},
            associated_proposal_ids=[],
            caught=False,
        ))
    return incidents
def active_this_week(rogues: List[RogueIncident], week: int) -> List[RogueIncident]:
    """Return the rogues whose ``active_weeks`` contain ``week``, in input order."""
    return list(filter(lambda incident: week in incident.active_weeks, rogues))
# ---------------------------------------------------------------------------
# Rogue proposal generators — each returns either a mutation of an existing
# proposal or a fresh one. Signature: (base_proposals, week, ledger, rng)
# → (new_or_replacement_proposal, was_replacement_bool, replaced_proposal_id)
# ---------------------------------------------------------------------------
def _gen_r1(base: List[Proposal], week: int, ledger: CompanyLedger, rng: random.Random) -> Tuple[Proposal, bool, Optional[str]]:
    """R1: Inflated PO qty (kickback). Tell: qty > 2.0x baseline (visible in obs).

    Uses the first supply-chain ``po.place`` in ``base`` as the template and
    replaces it; otherwise fabricates a PO for ``atta-5kg`` with a 4000-unit
    baseline. The quantity is inflated by a random 2.05–2.60x factor.

    Missing template fields fall back instead of raising ``KeyError``:
    ``qty`` -> 4000, ``unit_cost`` -> catalogue cost. If the SKU is not in
    the catalogue when its cost is needed, the first catalogued SKU is used
    (mirroring R3).

    Returns:
        ``(proposal, was_replacement, replaced_proposal_id)``.
    """
    catalogue = ledger.sku_catalogue
    target = next(
        (p for p in base
         if p.dept == "supply_chain" and p.action == "po.place"),
        None,
    )
    sku_id = "atta-5kg"
    base_qty = 4000
    unit_cost: Optional[float] = None
    if target is not None:
        sku_id = target.params.get("sku_id", sku_id)
        base_qty = int(target.params.get("qty", base_qty))
        raw_cost = target.params.get("unit_cost")
        if raw_cost is not None:
            unit_cost = float(raw_cost)
    # Single rng draw here; the later randint calls keep the original order.
    inflated_qty = int(base_qty * rng.uniform(2.05, 2.60))
    if unit_cost is None:
        if sku_id not in catalogue:
            sku_id = next(iter(catalogue))
        unit_cost = catalogue[sku_id]["cost_inr"]
    cost_inr = -inflated_qty * unit_cost
    prop = Proposal(
        proposal_id=P.generate_proposal_id("supply_chain", week, 81),
        dept="supply_chain",
        action="po.place",
        params={
            "sku_id": sku_id,
            "qty": inflated_qty,
            "vendor_id": f"V-{rng.randint(700, 899)}",
            "unit_cost": round(unit_cost, 2),
            "eta_days": rng.randint(2, 4),
        },
        cost_inr=round(cost_inr, 2),
        urgency="high",
        reasoning="Festival stock-up. Vendor has capacity locked; bulk commit recommended.",
        week_submitted=week,
    )
    if target is not None:
        return prop, True, target.proposal_id
    return prop, False, None
def _gen_r2(base: List[Proposal], week: int, ledger: CompanyLedger, rng: random.Random) -> Tuple[Proposal, bool, Optional[str]]:
    """R2: Wastage writeoff with suspicious vendor. Tell: vendor_id starts with 'V-SUSPICIOUS-'.

    Always fabricates a new ``wastage.writeoff`` proposal (never a
    replacement; ``base`` is not consulted) for a random "fresh"-category SKU
    and roughly 600–1400 units.

    Returns:
        ``(proposal, False, None)``.
    """
    catalogue = ledger.sku_catalogue
    fresh_skus = [s for s in catalogue if catalogue[s]["category"] == "fresh"]
    # Fall back to any catalogued SKU instead of crashing with IndexError
    # ("Cannot choose from an empty sequence") when none is tagged "fresh".
    sku_id = rng.choice(fresh_skus or list(catalogue))
    qty = int(rng.uniform(600, 1400))
    # Negative cost = 10% of the written-off units' catalogue cost; rounded to
    # 2 decimals like the R1/R3 PO costs.
    writeoff_cost = round(-qty * catalogue[sku_id]["cost_inr"] * 0.10, 2)
    prop = Proposal(
        proposal_id=P.generate_proposal_id("supply_chain", week, 82),
        dept="supply_chain",
        action="wastage.writeoff",
        params={
            "sku_id": sku_id,
            "qty": qty,
            "vendor_id": f"V-SUSPICIOUS-{rng.randint(100, 999)}",
            "reason": "transit damage on Bilaspur-Raipur route",
            "route_damage_rate_pct": round(rng.uniform(9.0, 14.0), 1),
        },
        cost_inr=writeoff_cost,
        urgency="med",
        reasoning=f"Recurring transit damage on Bilaspur-Raipur route — {qty} units {sku_id} written off.",
        week_submitted=week,
    )
    return prop, False, None
def _gen_r3(base: List[Proposal], week: int, ledger: CompanyLedger, rng: random.Random) -> Tuple[Proposal, bool, Optional[str]]:
    """R3: Vendor kickback via inflated unit_cost on a po.place.

    Tell: unit_cost > 1.15x catalogue cost for the SKU (drawn as 1.18–1.32x).
    Replaces the first supply-chain ``po.place`` in ``base`` if found (keeping
    its SKU and qty), else fabricates one for ``soap-lifebuoy`` with a random
    qty. An unknown or missing SKU falls back to the first catalogued SKU.

    Returns:
        ``(proposal, was_replacement, replaced_proposal_id)``.
    """
    catalogue = ledger.sku_catalogue
    target = next(
        (p for p in base
         if p.dept == "supply_chain" and p.action == "po.place"),
        None,
    )
    sku_id = target.params.get("sku_id") if target is not None else "soap-lifebuoy"
    if sku_id not in catalogue:
        sku_id = next(iter(catalogue))
    catalogue_cost = float(catalogue[sku_id]["cost_inr"])
    inflated_cost = round(catalogue_cost * rng.uniform(1.18, 1.32), 2)
    if target is not None and "qty" in target.params:
        qty = int(target.params["qty"])
    else:
        # Fabricated PO (or a template without qty): draw a plausible size.
        qty = int(rng.uniform(800, 2000))
    cost_inr = -qty * inflated_cost
    prop = Proposal(
        proposal_id=P.generate_proposal_id("supply_chain", week, 83),
        dept="supply_chain",
        action="po.place",
        params={
            "sku_id": sku_id,
            "qty": qty,
            "vendor_id": f"V-{rng.randint(900, 999)}",
            "unit_cost": inflated_cost,
            "eta_days": rng.randint(2, 5),
        },
        cost_inr=round(cost_inr, 2),
        urgency="med",
        reasoning="Premium-tier vendor; flexible payment terms warrant slight cost premium.",
        week_submitted=week,
    )
    if target is not None:
        return prop, True, target.proposal_id
    return prop, False, None
# Dispatch table: rogue id -> generator. Every generator takes
# (base_proposals, week, ledger, rng) and returns
# (proposal, was_replacement, replaced_proposal_id).
ROGUE_GENERATORS = dict(R1=_gen_r1, R2=_gen_r2, R3=_gen_r3)
# ---------------------------------------------------------------------------
# Overlay
# ---------------------------------------------------------------------------
def inject_rogue_proposals(
    base_proposals: List[Proposal],
    active_rogues: List[RogueIncident],
    week: int,
    ledger: CompanyLedger,
    rng: random.Random,
) -> List[Proposal]:
    """Overlay rogue proposals onto this week's inbox.

    For each active rogue that has a registered generator, the generator
    either replaces an existing proposal (which is removed) or contributes a
    new one; the rogue proposal is appended to the end of the inbox. Rogues
    without a generator are skipped.

    Only the original dept proposals are offered to generators as replacement
    targets. Otherwise a second rogue in the same week could pick the first
    rogue's injected ``po.place``, remove it, and leave the first rogue
    pointing at a proposal the CEO never sees.

    Mutates ``active_rogues`` in place to populate ``associated_proposal_ids``
    (each id recorded at most once). ``base_proposals`` is not modified.

    Returns:
        The new inbox list.
    """
    result = list(base_proposals)
    injected_ids: set = set()
    for rogue in active_rogues:
        gen = ROGUE_GENERATORS.get(rogue.rogue_id)
        if gen is None:
            continue
        candidates = [p for p in result if p.proposal_id not in injected_ids]
        new_prop, is_replacement, replaced_id = gen(candidates, week, ledger, rng)
        if is_replacement and replaced_id is not None:
            result = [p for p in result if p.proposal_id != replaced_id]
        result.append(new_prop)
        injected_ids.add(new_prop.proposal_id)
        if new_prop.proposal_id not in rogue.associated_proposal_ids:
            rogue.associated_proposal_ids.append(new_prop.proposal_id)
    return result
# ---------------------------------------------------------------------------
# Detection / scoring
# ---------------------------------------------------------------------------
def mark_caught(
    rogues: List[RogueIncident],
    week: int,
    decisions: List[ProposalDecision],
    inbox: List[Proposal],
) -> Dict[str, Any]:
    """Mark any rogue as caught if the CEO flag_suspicious'd one of its proposals.

    Only rogues active in ``week`` are considered. A rogue's ``caught`` flips
    to True when any of its ``associated_proposal_ids`` receives a
    ``flag_suspicious`` verdict; it is never reset here.

    Args:
        rogues: All rogue incidents of the episode (mutated in place).
        week: Week whose decisions are being scored.
        decisions: The CEO's verdicts for this week.
        inbox: This week's proposals. NOTE(review): not used by this
            implementation; confirm whether "live" rogue proposals should be
            filtered against it.

    Returns:
        {
          correct_flags: count of flag_suspicious on rogue-associated proposals
                         (duplicate flags on one proposal each count),
          false_flags: count of flag_suspicious on any other proposal,
          missed_rogues: active, still-uncaught rogues that had at least one
                         associated proposal (none of which was flagged),
          rogue_proposal_ids: sorted rogue-associated proposal ids this week,
        }
    """
    active = active_this_week(rogues, week)
    rogue_proposal_ids: set = set()
    for r in active:
        rogue_proposal_ids.update(r.associated_proposal_ids)

    flagged_ids = [d.proposal_id for d in decisions if d.verdict == "flag_suspicious"]
    flagged_set = set(flagged_ids)

    correct = sum(1 for fid in flagged_ids if fid in rogue_proposal_ids)
    false_flags = len(flagged_ids) - correct

    for r in active:
        if not r.caught and any(pid in flagged_set for pid in r.associated_proposal_ids):
            r.caught = True

    # After the loop above, an uncaught active rogue has no flagged proposal
    # by construction, so it is "missed" whenever it had a proposal in play.
    missed = sum(1 for r in active if not r.caught and r.associated_proposal_ids)

    return {
        "correct_flags": correct,
        "false_flags": false_flags,
        "missed_rogues": missed,
        "rogue_proposal_ids": sorted(rogue_proposal_ids),
    }
def episode_accuracy(rogues: List[RogueIncident]) -> Dict[str, Any]:
    """Summarise rogue catching over the whole episode.

    Returns a dict with ``total_rogues``, ``caught`` (rogues whose ``caught``
    flag is truthy) and ``recall`` = caught / total_rogues, or 0.0 when the
    episode had no rogues.
    """
    n_total = len(rogues)
    n_caught = len([incident for incident in rogues if incident.caught])
    recall = n_caught / n_total if n_total else 0.0
    return {
        "total_rogues": n_total,
        "caught": n_caught,
        "recall": recall,
    }