# SimMart / tests/test_environment.py
# Viani's picture
# HF Space: 4-dept SimMart env + 1.5B SFT+GRPO training (hackathon submission)
# 5c35138
"""Integration tests for SimMartEnvironment.
Verifies invariants that must hold for any episode regardless of CEO policy:
• reset returns a valid week-1 observation with 6–12 inbox items
• step advances week by 1 and day by up to 7
• episode terminates exactly at week 8 with done=True
• all returned obs/state pydantic-validate cleanly
Also runs an approve-all 8-week episode for 3 seeds to ensure stability
(no crashes, reward always in [-5, +5], P&L doesn't explode).
Run locally: `python -m unittest tests.test_environment`
"""
from __future__ import annotations
import os
import sys
import unittest
HERE = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(HERE)
if ROOT not in sys.path:
sys.path.insert(0, ROOT)
from models import ExecutiveDiligenceRequest, ProposalDecision, SimMartAction, SimMartObservation
from server.environment import SimMartEnvironment
def _approve_all_action(obs: SimMartObservation, week: int) -> SimMartAction:
    """Build an action that approves every proposal in ``obs.inbox``.

    Args:
        obs: Observation whose inbox proposals are to be approved.
        week: Week number; only used in the journal entry text.

    Returns:
        A ``SimMartAction`` with one "approve" decision per inbox proposal,
        a fixed budget allocation for ``supply_chain`` and ``store_ops``,
        and a one-line journal entry.
    """
    approvals = []
    for proposal in obs.inbox:
        approvals.append(
            ProposalDecision(proposal_id=proposal.proposal_id, verdict="approve")
        )

    # Only these two departments receive an explicit allocation here.
    allocations = {"supply_chain": 1e7, "store_ops": 1e6}

    return SimMartAction(
        decisions=approvals,
        budget_allocations=allocations,
        journal_entry=f"Week {week}: approved all.",
    )
class EnvContractTests(unittest.TestCase):
    """Contract checks on ``SimMartEnvironment.reset`` and ``step``.

    The stepping tests drive the environment with the approve-all policy
    built by ``_approve_all_action``.
    """

    # Number of weekly steps after which these tests expect ``done=True``.
    EPISODE_WEEKS = 8
    # Symmetric bound asserted on each weekly reward.
    REWARD_BOUND = 5.0
    # Upper sanity bound on quarter-to-date revenue (< ₹200 Cr).
    MAX_REVENUE_QTD_INR = 2e9

    def test_reset_returns_valid_observation(self):
        """Reset yields a non-terminal, reward-less week-1 observation."""
        env = SimMartEnvironment()
        obs = env.reset(seed=42, episode_id="contract-1")
        self.assertFalse(obs.done)
        self.assertIsNone(obs.reward)
        self.assertEqual(obs.week_of_quarter, 1)
        inbox_size = len(obs.inbox)
        self.assertGreaterEqual(inbox_size, 6)
        self.assertLessEqual(inbox_size, 12)
        self.assertEqual(len(obs.schema_hash), 8)
        self.assertIsInstance(obs, SimMartObservation)

    def test_each_step_advances_one_week(self):
        """Each of the first three steps moves the week counter forward by one."""
        env = SimMartEnvironment()
        obs = env.reset(seed=42)
        prev_week = obs.week_of_quarter
        for _ in range(3):
            obs = env.step(_approve_all_action(obs, prev_week))
            expected_week = prev_week + 1
            self.assertEqual(obs.week_of_quarter, expected_week)
            # The day counter may not run past the end of the new week.
            self.assertLessEqual(obs.day_of_quarter, 7 * expected_week)
            prev_week = obs.week_of_quarter

    def test_episode_terminates_at_week_8(self):
        """``done`` first becomes true on the final weekly step, not earlier."""
        env = SimMartEnvironment()
        obs = env.reset(seed=42)
        for week in range(1, self.EPISODE_WEEKS + 1):
            obs = env.step(_approve_all_action(obs, week))
            if obs.done:
                self.assertEqual(
                    week, self.EPISODE_WEEKS, f"Terminated early at week {week}"
                )
                break
        self.assertTrue(obs.done)
        self.assertIsNotNone(obs.reward)

    def test_episode_no_crash_across_seeds(self):
        """A full approve-all episode stays within sanity bounds for 3 seeds."""
        for seed in (7, 42, 101):
            with self.subTest(seed=seed):
                env = SimMartEnvironment()
                obs = env.reset(seed=seed, episode_id=f"stability-{seed}")
                for week in range(1, self.EPISODE_WEEKS + 1):
                    obs = env.step(_approve_all_action(obs, week))
                    # A missing (None) reward is treated as 0.0 for the
                    # bounds check; the message still shows the raw value.
                    weekly_reward = obs.reward or 0.0
                    self.assertTrue(
                        -self.REWARD_BOUND <= weekly_reward <= self.REWARD_BOUND,
                        f"weekly reward out of bounds at W{week}: {obs.reward}",
                    )
                self.assertTrue(obs.done)
                # One history entry is expected per weekly step.
                self.assertEqual(len(env.state.history), self.EPISODE_WEEKS)
                # Revenue must be positive and below the sanity ceiling.
                rev = env.state.company.pnl_qtd.revenue_qtd_inr
                self.assertGreater(rev, 0.0)
                self.assertLess(rev, self.MAX_REVENUE_QTD_INR)
class InboxAndStateHygieneTests(unittest.TestCase):
    """Checks on inbox proposal fields, step history and hidden state."""

    def test_inbox_fields_are_populated(self):
        """Every week-1 proposal carries non-empty, well-typed core fields."""
        sim = SimMartEnvironment()
        first_obs = sim.reset(seed=42)
        known_depts = {"supply_chain", "store_ops", "finance", "growth"}
        known_urgencies = {"low", "med", "high"}
        for proposal in first_obs.inbox:
            self.assertTrue(proposal.proposal_id)
            self.assertIn(proposal.dept, known_depts)
            self.assertTrue(proposal.action)
            self.assertIsInstance(proposal.params, dict)
            self.assertIn(proposal.urgency, known_urgencies)

    def test_state_records_history_after_each_step(self):
        """After step ``w`` the history has ``w`` entries, the last for week ``w``."""
        sim = SimMartEnvironment()
        obs = sim.reset(seed=42)
        for week in range(1, 6):
            obs = sim.step(_approve_all_action(obs, week))
            self.assertEqual(len(sim.state.history), week)
            latest = sim.state.history[-1]
            self.assertEqual(latest.week, week)
            self.assertIsNotNone(latest.kpi_snapshot)

    def test_rogues_are_hidden_from_observation(self):
        """Observation must never leak RogueIncident.tell / scenario."""
        sim = SimMartEnvironment()
        obs = sim.reset(seed=42)
        # These attributes are expected on the state only, never on the observation.
        for hidden_attr in ("rogue_incidents", "dept_drifts"):
            self.assertFalse(hasattr(obs, hidden_attr))
class ExecutiveDiligenceTests(unittest.TestCase):
    """Tests for the executive-diligence request and finding flow."""

    def test_reset_surfaces_full_dashboard_and_diligence_budget(self):
        """Reset exposes KPI and P&L snapshots, a budget of 2 and no findings."""
        sim = SimMartEnvironment()
        first_obs = sim.reset(seed=42)
        self.assertIsNotNone(first_obs.kpi_snapshot)
        self.assertIsNotNone(first_obs.pnl_snapshot)
        self.assertEqual(first_obs.diligence_budget_remaining, 2)
        self.assertEqual(first_obs.executive_diligence_findings, [])

    def test_diligence_request_generates_next_week_finding_and_cost(self):
        """One audit request yields one completed, costed finding next step."""
        sim = SimMartEnvironment()
        obs = sim.reset(seed=42, episode_id="diligence-1")
        audited = obs.inbox[0]
        audit_request = ExecutiveDiligenceRequest(
            request_type="vendor_audit",
            proposal_id=audited.proposal_id,
            rationale="CEO wants audit-level scrutiny on a costly proposal",
        )
        action = _approve_all_action(obs, 1)
        action.diligence_requests = [audit_request]

        next_obs = sim.step(action)

        findings = next_obs.executive_diligence_findings
        self.assertEqual(len(findings), 1)
        finding = findings[0]
        self.assertEqual(finding.proposal_id, audited.proposal_id)
        self.assertEqual(finding.request_type, "vendor_audit")
        self.assertEqual(finding.status, "completed")
        self.assertGreater(finding.cost_inr, 0)
        # Quarter-to-date opex must exceed the cost of the finding alone.
        self.assertGreater(sim.state.company.pnl_qtd.opex_qtd_inr, finding.cost_inr)
        # The latest history entry records both the request and its finding.
        self.assertEqual(len(sim.state.history[-1].diligence_requests), 1)
        self.assertEqual(len(sim.state.history[-1].diligence_findings), 1)

    def test_diligence_budget_caps_completed_requests(self):
        """Of three requests in one step, two complete and one is over capacity."""
        sim = SimMartEnvironment()
        obs = sim.reset(seed=42, episode_id="diligence-2")
        action = _approve_all_action(obs, 1)
        batch = []
        for proposal in obs.inbox[:3]:
            batch.append(
                ExecutiveDiligenceRequest(
                    request_type="vendor_audit",
                    proposal_id=proposal.proposal_id,
                    rationale="batch audit",
                )
            )
        action.diligence_requests = batch

        next_obs = sim.step(action)

        statuses = [
            finding.status for finding in next_obs.executive_diligence_findings
        ]
        self.assertEqual(statuses.count("completed"), 2)
        self.assertEqual(statuses.count("capacity_exceeded"), 1)
class UnknownSkuHardeningTests(unittest.TestCase):
    """Regression tests for the v6b crash mode.

    In that mode a CEO modify (or an upstream LLM-rewritten proposal) supplied
    a sku_id absent from SKU_CATALOGUE, which then crashed the next day's
    apply_shrinkage_and_spoilage tick.
    """

    @staticmethod
    def _modify_to_unknown_sku(obs: SimMartObservation, week: int, bad_sku: str = "shampoo-100ml") -> SimMartAction:
        """Build an action that rewrites sku-bearing proposals to ``bad_sku``.

        Proposals that have a ``sku_id`` param and one of the actions
        ``po.place``, ``po.bulk_deal`` or ``wastage.writeoff`` get a "modify"
        verdict with ``sku_id`` replaced by ``bad_sku``. All other proposals
        are approved unchanged.

        Args:
            obs: Observation whose inbox is being decided.
            week: Week number; only used in the journal entry text.
            bad_sku: The sku_id to inject into the modified proposals.

        Returns:
            A ``SimMartAction`` with one decision per inbox proposal.
        """
        sku_actions = ("po.place", "po.bulk_deal", "wastage.writeoff")
        verdicts = []
        for proposal in obs.inbox:
            if "sku_id" not in proposal.params or proposal.action not in sku_actions:
                verdicts.append(
                    ProposalDecision(proposal_id=proposal.proposal_id, verdict="approve")
                )
                continue
            verdicts.append(
                ProposalDecision(
                    proposal_id=proposal.proposal_id,
                    verdict="modify",
                    modified_params={"sku_id": bad_sku},
                )
            )

        allocations = {"supply_chain": 1e7, "store_ops": 1e6}
        return SimMartAction(
            decisions=verdicts,
            budget_allocations=allocations,
            journal_entry=f"Week {week}: hardening test.",
        )

    def test_modify_with_unknown_sku_does_not_crash(self):
        """An unknown-sku modify in week 1 must not break the following steps."""
        sim = SimMartEnvironment()
        obs = sim.reset(seed=42, episode_id="hardening-1")

        # Week 1: every sku-bearing proposal is modified to the unknown sku.
        obs = sim.step(self._modify_to_unknown_sku(obs, week=1))
        self.assertIsNotNone(obs.reward)

        # Weeks 2–4: plain approve-all, so the daily shrinkage tick runs over
        # the inventory state left by week 1. This is the original crash site.
        for week in range(2, 5):
            obs = sim.step(_approve_all_action(obs, week))
            self.assertIsNotNone(obs.reward)

    def test_full_episode_with_unknown_sku_modifies(self):
        """Every step modifies sku-bearing proposals to an unknown sku; episode
        must run cleanly to terminal week 8."""
        sim = SimMartEnvironment()
        obs = sim.reset(seed=7, episode_id="hardening-2")
        for week in range(1, 9):
            obs = sim.step(self._modify_to_unknown_sku(obs, week))
            weekly_reward = obs.reward or 0.0
            self.assertTrue(
                -5.0 <= weekly_reward <= 5.0,
                f"weekly reward out of bounds at W{week}: {obs.reward}",
            )
            if obs.done:
                break
        self.assertTrue(obs.done)
        self.assertEqual(len(sim.state.history), 8)

    def test_unknown_sku_does_not_create_phantom_inventory(self):
        """After a modify with an unknown sku_id, ledger.inventory must not
        contain that sku_id — neither created at write nor surviving the next tick."""
        sim = SimMartEnvironment()
        obs = sim.reset(seed=99, episode_id="hardening-3")
        obs = sim.step(
            self._modify_to_unknown_sku(obs, week=1, bad_sku="shampoo-100ml")
        )
        company = sim.state.company
        self.assertNotIn("shampoo-100ml", company.inventory)
        # The catalogue must also be left untouched by the bad sku.
        self.assertNotIn("shampoo-100ml", sim.state.company.sku_catalogue)
# Allow running this file directly as a script, with per-test verbose output.
if __name__ == "__main__":
    unittest.main(verbosity=2)