procure-env / server /environment.py
prabhjot27's picture
refactor: humanize codebase for hackathon credibility
00f0ff4
import uuid
import copy
from typing import Optional
from models import (
ProcureObservation, ProcureState, SupplierVisible,
QueryAction, RequestDocAction, OfferAction, AcceptAction, RejectAction
)
from tasks import TASKS
# --- Reward weights for the terminal score breakdown ---
COST_EFFICIENCY_WEIGHT = 0.40
CERT_COMPLIANCE_WEIGHT = 0.30
QUALITY_CHECK_WEIGHT = 0.20
DUE_DILIGENCE_WEIGHT = 0.10
# Deception penalty: accepting a supplier who ran a bait-and-switch multiplies
# the total terminal score by this factor. 0.4 is intentionally harsh -- in real
# procurement, getting locked into a deceptive supplier costs time and money to unwind.
DECEPTION_PENALTY_MULTIPLIER = 0.40
# Quality threshold below which a supplier should be rejected.
# Matches real ISO inspection acceptance criteria (~60% pass rate floor).
QUALITY_THRESHOLD = 0.60
# Score bounds: the OpenEnv validator rejects exactly 0.0 and 1.0.
# Clamp all final scores to this open interval.
SCORE_MIN = 0.001
SCORE_MAX = 0.999
# Per-step reward for revealing new information.
REWARD_QUERY_NEW = 0.01 # query field not previously known
REWARD_DOC_NEW = 0.03 # request_doc not previously seen
REWARD_ISSUE_FOUND = 0.05 # quality < threshold or missing required cert discovered
REWARD_DECEPTIVE_BAIT = 0.04 # deceptive supplier "accepts" during negotiation (bait)
class ProcureEnvironment:
"""
Core environment logic for one procurement episode.
One instance per WebSocket session. All mutable state lives on self --
there is no shared state between sessions. The environment is reset
explicitly via reset(), not at construction time, so the same instance
can be reused across episodes (the server currently creates a new instance
per session, but this pattern supports reuse).
Episode lifecycle:
1. reset(task_id) -- loads task, initialises supplier pool, returns obs
2. step(action) -- processes one agent action, returns next obs + reward
3. Repeat until obs.done == True (accept, all-rejected, or step budget exhausted)
"""
def __init__(self, task_id: str = "task1_easy"):
self.task_id = task_id
self._task = TASKS[task_id]
self._episode_id: Optional[str] = None
self._suppliers: list[dict] = []
self._rejected_ids: set[int] = set()
self._accepted_id: Optional[int] = None
self._revealed: dict[str, dict] = {}
self._best_offers: dict[str, float] = {}
self._step_count: int = 0
self._cumulative_reward: float = 0.0
self._done: bool = False
self._deceptive_trap_triggered: dict[int, bool] = {}
# ------------------------------------------------------------------ #
# Public API #
# ------------------------------------------------------------------ #
def reset(self) -> ProcureObservation:
self._episode_id = str(uuid.uuid4())[:8]
self._suppliers = copy.deepcopy(self._task["suppliers"])
self._rejected_ids = set()
self._accepted_id = None
self._revealed = {str(s["id"]): {} for s in self._suppliers}
self._best_offers = {str(s["id"]): s["quoted_price"] for s in self._suppliers}
self._step_count = 0
self._cumulative_reward = 0.0
self._done = False
self._deceptive_trap_triggered = {s["id"]: False for s in self._suppliers}
rfq = self._task["rfq"]
supplier_names = ", ".join(s["name"] for s in self._suppliers)
certs_note = (
f" Required certifications: {rfq['required_certs']}."
if rfq["required_certs"] else ""
)
msg = (
f"RFQ: {rfq['item']} -- {rfq['quantity']} units, "
f"budget ₹{rfq['budget']:,.0f}, deadline {rfq['deadline_days']} days.{certs_note} "
f"Suppliers: {supplier_names}. "
f"Step budget: {self._task['max_steps']} steps. "
f"Verify certifications and quality before accepting."
)
return self._build_observation(reward=0.0, message=msg)
def step(self, action: dict) -> ProcureObservation:
if self._done:
return self._build_observation(reward=0.0, message="Episode already done. Call reset().")
self._step_count += 1
action_type = action.get("action")
reward = 0.0
message = ""
if action_type == "query":
reward, message = self._handle_query(action)
elif action_type == "request_doc":
reward, message = self._handle_request_doc(action)
elif action_type == "offer":
reward, message = self._handle_offer(action)
elif action_type == "accept":
reward, message = self._handle_accept(action)
elif action_type == "reject":
reward, message = self._handle_reject(action)
else:
message = f"Unknown action '{action_type}'. Valid: query, request_doc, offer, accept, reject."
reward = -0.02
steps_remaining = self._task["max_steps"] - self._step_count
if steps_remaining <= 0 and not self._done:
self._done = True
message += " TIMEOUT: step budget exhausted without completing procurement."
self._cumulative_reward += reward
return self._build_observation(reward=reward, message=message)
@property
def state(self) -> ProcureState:
return ProcureState(
task_id=self.task_id,
episode_id=self._episode_id or "",
step_count=self._step_count,
done=self._done,
cumulative_reward=self._cumulative_reward,
accepted_supplier_id=self._accepted_id,
suppliers_hidden=self._suppliers
)
# ------------------------------------------------------------------ #
# Action Handlers #
# ------------------------------------------------------------------ #
def _handle_query(self, action: dict) -> tuple[float, str]:
sid = action.get("supplier_id")
field = action.get("field")
supplier = self._get_supplier(sid)
if not supplier:
return -0.01, f"Supplier ID {sid} not found."
if sid in self._rejected_ids:
return -0.01, f"{supplier['name']} has already been rejected."
field_map = {
"lead_time": ("lead_time_days", "lead time (days)"),
"moq": ("moq", "minimum order quantity"),
"reliability": ("reliability", "reliability score"),
}
if field not in field_map:
return -0.01, f"Unknown field '{field}'. Valid: lead_time, moq, reliability."
internal_key, display = field_map[field]
already_known = field in self._revealed[str(sid)]
value = supplier[internal_key]
self._revealed[str(sid)][field] = value
name = supplier["name"]
if already_known:
return 0.0, f"{name}: {display} = {value} (already on record)."
# Add context when the value is decision-relevant
note = ""
if field == "lead_time" and self._task["rfq"].get("deadline_days"):
deadline = self._task["rfq"]["deadline_days"]
if value > deadline:
note = f" That exceeds your {deadline}-day deadline -- worth flagging."
if field == "reliability" and isinstance(value, float) and value < 0.80:
note = " Reliability below 0.80 -- verify carefully before committing."
return REWARD_QUERY_NEW, f"{name}: {display} = {value}.{note}"
def _handle_request_doc(self, action: dict) -> tuple[float, str]:
sid = action.get("supplier_id")
doc_type = action.get("doc_type")
supplier = self._get_supplier(sid)
if not supplier:
return -0.01, f"Supplier ID {sid} not found."
if sid in self._rejected_ids:
return -0.01, f"{supplier['name']} has already been rejected."
doc_map = {
"quality_report": ("quality_score", "quality score"),
"certifications": ("certifications", "certifications"),
"financial_stability": ("reliability", "financial stability / reliability"),
}
if doc_type not in doc_map:
return -0.01, f"Unknown doc_type '{doc_type}'. Valid: quality_report, certifications, financial_stability."
internal_key, display = doc_map[doc_type]
already_known = doc_type in self._revealed[str(sid)]
value = supplier[internal_key]
self._revealed[str(sid)][doc_type] = value
name = supplier["name"]
if already_known:
return 0.0, f"{name} {display}: {value} (already on record)."
extra_msg = ""
reward = REWARD_DOC_NEW
if doc_type == "quality_report" and isinstance(value, float):
grade = (
"Excellent" if value >= 0.85 else
"Good" if value >= 0.70 else
"Acceptable" if value >= QUALITY_THRESHOLD else
"Below threshold"
)
extra_msg = f" ({grade})"
if value < QUALITY_THRESHOLD:
extra_msg += f" -- below the {QUALITY_THRESHOLD} floor. Recommend rejecting."
reward = REWARD_ISSUE_FOUND
if doc_type == "certifications":
required = self._task["rfq"]["required_certs"]
missing = [c for c in required if c not in value]
if missing:
extra_msg = f" WARNING: missing required certs: {missing}. This supplier cannot be accepted."
reward = REWARD_ISSUE_FOUND
return reward, f"{name} {display}: {value}.{extra_msg}"
def _handle_offer(self, action: dict) -> tuple[float, str]:
sid = action.get("supplier_id")
offered_price = action.get("price")
supplier = self._get_supplier(sid)
if not supplier:
return -0.01, f"Supplier ID {sid} not found."
if sid in self._rejected_ids:
return -0.01, f"{supplier['name']} has already been rejected."
if sid == self._accepted_id:
return -0.01, f"{supplier['name']} is already accepted."
name = supplier["name"]
min_price = supplier["min_price"]
behavior = supplier["behavior"]
current_price = self._best_offers[str(sid)]
if offered_price >= current_price:
return -0.01, (
f"{name} noted ₹{offered_price:,.0f}/unit but their current price is "
f"₹{current_price:,.0f}. Offer below their current price to negotiate."
)
if behavior == "deceptive":
# Deceptive supplier accepts anything during negotiation -- trap springs on accept().
# Record that the bait was taken so accept() can fire the revision.
self._best_offers[str(sid)] = offered_price
self._deceptive_trap_triggered[sid] = True
return REWARD_DECEPTIVE_BAIT, (
f"{name} confirmed ₹{offered_price:,.0f}/unit. "
f"Use accept to finalise -- verify their reliability score first."
)
if offered_price < min_price:
# Offered below floor: supplier counters at their minimum.
self._best_offers[str(sid)] = min_price
improvement = (current_price - min_price) / supplier["quoted_price"]
reward = improvement * 0.3
return reward, (
f"{name} cannot go below ₹{min_price:,.0f}/unit -- that's their floor. "
f"Countered at ₹{min_price:,.0f}."
)
if behavior == "flexible":
self._best_offers[str(sid)] = offered_price
improvement = (current_price - offered_price) / supplier["quoted_price"]
reward = improvement * 0.5
return reward, (
f"{name} accepted ₹{offered_price:,.0f}/unit. "
f"Price locked in. Use accept to finalise."
)
if behavior == "firm":
# Firm supplier splits the difference once, then holds.
counter = round((offered_price + min_price) / 2, 0)
self._best_offers[str(sid)] = counter
improvement = (current_price - counter) / supplier["quoted_price"]
reward = improvement * 0.3
return reward, (
f"{name} came down to ₹{counter:,.0f}/unit -- that's their best offer."
)
return 0.0, f"Offer to {name} processed."
def _handle_accept(self, action: dict) -> tuple[float, str]:
sid = action.get("supplier_id")
supplier = self._get_supplier(sid)
if not supplier:
return -0.01, f"Supplier ID {sid} not found."
if sid in self._rejected_ids:
return -0.01, f"Cannot accept {supplier['name']} -- already rejected."
name = supplier["name"]
# Deception trap: FluidDyn-style supplier agreed during negotiation but springs
# a price revision on the formal accept call. The agent should have noticed the
# low reliability score and missing CE cert before reaching this point.
if supplier["behavior"] == "deceptive" and self._deceptive_trap_triggered.get(sid):
agreed = self._best_offers[str(sid)]
revised = round(agreed * 1.20, 0)
self._best_offers[str(sid)] = revised
self._done = True
self._accepted_id = sid
final_score = self._compute_final_score(sid, was_deceived=True)
# Note: step() adds the returned reward to cumulative_reward -- do not add here.
return final_score, (
f"DECEPTION: {name} has revised their price to ₹{revised:,.0f}/unit "
f"(+20% from the ₹{agreed:,.0f} agreed during negotiation), "
f"citing 'raw material cost escalation.' "
f"Deal locked in at revised price. "
f"Final score: {final_score:.3f}. "
f"Hint: low reliability score and missing certifications were warning signs."
)
self._accepted_id = sid
self._done = True
final_score = self._compute_final_score(sid, was_deceived=False)
price = self._best_offers[str(sid)]
total_cost = price * self._task["rfq"]["quantity"]
budget = self._task["rfq"]["budget"]
required_certs = self._task["rfq"]["required_certs"]
supplier_certs = supplier["certifications"]
missing_certs = [c for c in required_certs if c not in supplier_certs]
cert_msg = ""
if missing_certs:
cert_msg = f" WARNING: missing required certs {missing_certs} -- compliance penalty applied."
budget_note = (
f"within budget (saved ₹{budget - total_cost:,.0f})"
if total_cost <= budget
else f"OVER BUDGET by ₹{total_cost - budget:,.0f}"
)
return final_score, (
f"Deal closed: {name} at ₹{price:,.0f}/unit. "
f"Total: ₹{total_cost:,.0f} -- {budget_note}.{cert_msg} "
f"Final score: {final_score:.3f}."
)
def _handle_reject(self, action: dict) -> tuple[float, str]:
sid = action.get("supplier_id")
supplier = self._get_supplier(sid)
if not supplier:
return -0.01, f"Supplier ID {sid} not found."
if sid in self._rejected_ids:
return 0.0, f"{supplier['name']} was already rejected."
self._rejected_ids.add(sid)
name = supplier["name"]
all_ids = {s["id"] for s in self._suppliers}
if self._rejected_ids == all_ids:
self._done = True
return -0.5, f"Rejected {name}. All suppliers eliminated -- procurement failed."
remaining = len(all_ids) - len(self._rejected_ids)
return 0.0, f"Rejected {name}. {remaining} supplier(s) still under consideration."
# ------------------------------------------------------------------ #
# Grader / Final Score #
# ------------------------------------------------------------------ #
def _compute_final_score(self, accepted_sid: int, was_deceived: bool) -> float:
"""
Terminal reward emitted when the agent accepts a supplier.
Weighted across four procurement dimensions:
cost_efficiency (40%):
How close the negotiated deal price is to the theoretical best possible
price across all valid suppliers (those with required certs + quality >= 0.60).
Accepting at the highest quoted price among valid suppliers = 0.0 on this
component. Hitting the lowest min_price among valid suppliers = 0.40.
Over-budget acceptance scores 0.0 regardless.
cert_compliance (30%):
Fraction of required RFQ certifications held by the accepted supplier.
All certs present = 0.30. One missing cert halves this. None = 0.0.
If no certs are required, full credit is granted automatically.
quality_check (20%):
Full credit (0.20) only if: (a) quality_report was explicitly requested
before accept(), AND (b) the supplier's quality_score >= 0.60.
Skipping the quality check = 0.0 here, regardless of actual quality.
Requesting the report and finding poor quality = 0.05 (partial credit for
due diligence, penalised for accepting anyway).
due_diligence (10%):
Proportion of queryable attributes checked: lead_time, moq, reliability,
certifications. Each checked field is worth 0.04, capped at 0.10.
Deception penalty:
If the agent was deceived (accepted a deceptive supplier after the price
revision), multiply the total by DECEPTION_PENALTY_MULTIPLIER (0.40).
This models the real cost of being locked into a bad contract.
Returns a value clamped to (SCORE_MIN, SCORE_MAX) -- the OpenEnv validator
rejects exactly 0.0 or 1.0.
"""
supplier = self._get_supplier(accepted_sid)
rfq = self._task["rfq"]
revealed = self._revealed[str(accepted_sid)]
# 1. Cost efficiency (0.0 -- 0.40)
final_price = self._best_offers[str(accepted_sid)]
total_cost = final_price * rfq["quantity"]
budget = rfq["budget"]
if total_cost > budget:
cost_score = 0.0
else:
valid_suppliers = [
s for s in self._suppliers
if all(c in s["certifications"] for c in rfq["required_certs"])
and s["quality_score"] >= QUALITY_THRESHOLD
]
if valid_suppliers:
best_floor = min(s["min_price"] for s in valid_suppliers) * rfq["quantity"]
worst_quoted = max(s["quoted_price"] for s in valid_suppliers) * rfq["quantity"]
spread = max(worst_quoted - best_floor, 1)
cost_score = COST_EFFICIENCY_WEIGHT * max(0.0, (worst_quoted - total_cost) / spread)
else:
cost_score = 0.20 # no valid suppliers exist -- partial credit
# 2. Certification compliance (0.0 -- 0.30)
required_certs = rfq["required_certs"]
if not required_certs:
cert_score = CERT_COMPLIANCE_WEIGHT
else:
supplier_certs = supplier["certifications"]
certs_met = sum(1 for c in required_certs if c in supplier_certs)
cert_score = CERT_COMPLIANCE_WEIGHT * (certs_met / len(required_certs))
# 3. Quality check (0.0 -- 0.20)
if "quality_report" in revealed:
quality_val = supplier["quality_score"]
quality_score = QUALITY_CHECK_WEIGHT if quality_val >= QUALITY_THRESHOLD else 0.05
else:
quality_score = 0.0
# 4. Due diligence (0.0 -- 0.10)
diligence_fields = {"lead_time", "moq", "reliability", "certifications"}
checks = len([k for k in revealed if k in diligence_fields])
diligence_score = min(DUE_DILIGENCE_WEIGHT, checks * 0.04)
total = cost_score + cert_score + quality_score + diligence_score
if was_deceived:
total *= DECEPTION_PENALTY_MULTIPLIER
return round(min(SCORE_MAX, max(SCORE_MIN, total)), 3)
# ------------------------------------------------------------------ #
# Helpers #
# ------------------------------------------------------------------ #
def _get_supplier(self, sid: int) -> Optional[dict]:
for s in self._suppliers:
if s["id"] == sid:
return s
return None
def _build_observation(self, reward: float, message: str) -> ProcureObservation:
visible = []
for s in self._suppliers:
if s["id"] in self._rejected_ids:
status = "rejected"
elif s["id"] == self._accepted_id:
status = "accepted"
else:
status = "active"
visible.append(SupplierVisible(
id=s["id"],
name=s["name"],
quoted_price=s["quoted_price"],
item_category=s["item_category"],
status=status
))
enriched = self._enrich_message(message)
return ProcureObservation(
rfq=self._task["rfq"],
suppliers=visible,
revealed_info=self._revealed,
current_best_offers=self._best_offers,
step_count=self._step_count,
steps_remaining=self._task["max_steps"] - self._step_count,
done=self._done,
reward=reward,
cumulative_reward=self._cumulative_reward,
accepted_supplier_id=self._accepted_id,
message=enriched
)
def _enrich_message(self, base_message: str) -> str:
"""
Append decision-relevant context to every outbound message.
Adds: steps remaining, unchecked certification warnings for active suppliers,
quality report reminders, and a summary of revealed info. These hints guide
the agent without giving away hidden values -- they flag *what to check*,
not the check results.
Skipped when the episode is already done (no further actions possible).
"""
steps_remaining = self._task["max_steps"] - self._step_count
if self._done or steps_remaining <= 0:
return base_message
required_certs = self._task["rfq"]["required_certs"]
parts = [base_message]
parts.append(f"Steps remaining: {steps_remaining}/{self._task['max_steps']}.")
# Flag active suppliers with unchecked certifications when certs are required
if required_certs:
unchecked_certs = [
s["name"] for s in self._suppliers
if s["id"] not in self._rejected_ids
and s["id"] != self._accepted_id
and "certifications" not in self._revealed.get(str(s["id"]), {})
]
if unchecked_certs:
parts.append(
f"Cert check pending for: {', '.join(unchecked_certs)}. "
f"Required: {required_certs}."
)
# Flag active suppliers without a quality report
no_quality = [
s["name"] for s in self._suppliers
if s["id"] not in self._rejected_ids
and s["id"] != self._accepted_id
and "quality_report" not in self._revealed.get(str(s["id"]), {})
]
if no_quality:
parts.append(f"Quality report not yet requested for: {', '.join(no_quality)}.")
# Summarise what's already known about active suppliers
known_summaries = []
for s in self._suppliers:
if s["id"] in self._rejected_ids:
continue
revealed = self._revealed.get(str(s["id"]), {})
if revealed:
kvs = ", ".join(f"{k}={v}" for k, v in revealed.items())
known_summaries.append(f"{s['name']}: {kvs}")
if known_summaries:
parts.append(f"Known: {'; '.join(known_summaries)}.")
return " | ".join(parts)