Spaces:
Sleeping
Sleeping
| import uuid | |
| import copy | |
| from typing import Optional | |
| from models import ( | |
| ProcureObservation, ProcureState, SupplierVisible, | |
| QueryAction, RequestDocAction, OfferAction, AcceptAction, RejectAction | |
| ) | |
| from tasks import TASKS | |
# --- Reward weights for the terminal score breakdown (they sum to 1.0) ---
COST_EFFICIENCY_WEIGHT = 0.40
CERT_COMPLIANCE_WEIGHT = 0.30
QUALITY_CHECK_WEIGHT = 0.20
DUE_DILIGENCE_WEIGHT = 0.10

# Partial credits used inside the terminal score.
QUALITY_PARTIAL_CREDIT = 0.05          # quality report requested, but quality below threshold
DILIGENCE_PER_FIELD = 0.04             # per checked diligence field, capped at DUE_DILIGENCE_WEIGHT
NO_VALID_SUPPLIER_COST_CREDIT = 0.20   # cost credit when no supplier meets certs + quality

# Deception penalty: accepting a supplier who ran a bait-and-switch multiplies
# the total terminal score by this factor. 0.4 is intentionally harsh -- in real
# procurement, getting locked into a deceptive supplier costs time and money to unwind.
DECEPTION_PENALTY_MULTIPLIER = 0.40
# Price uplift a deceptive supplier springs on accept() after taking the bait.
DECEPTIVE_PRICE_REVISION = 1.20

# Quality threshold below which a supplier should be rejected.
# Intended to mirror an inspection acceptance floor of ~60%.
QUALITY_THRESHOLD = 0.60

# Reliability level below which a query() result carries a caution note.
RELIABILITY_CAUTION = 0.80

# Score bounds: the OpenEnv validator rejects exactly 0.0 and 1.0.
# Clamp all final scores to this open interval.
SCORE_MIN = 0.001
SCORE_MAX = 0.999

# Per-step reward for revealing new information.
REWARD_QUERY_NEW = 0.01       # query field not previously known
REWARD_DOC_NEW = 0.03         # request_doc not previously seen
REWARD_ISSUE_FOUND = 0.05     # quality < threshold or missing required cert discovered
REWARD_DECEPTIVE_BAIT = 0.04  # deceptive supplier "accepts" during negotiation (bait)

# Shaping multipliers applied to an offer's price improvement
# (expressed as a fraction of the supplier's quoted price).
OFFER_SHAPING_ACCEPTED = 0.5  # supplier accepted the buyer's price as offered
OFFER_SHAPING_COUNTER = 0.3   # supplier countered (at its floor, or a firm split)

# Per-step penalties.
PENALTY_INVALID_ACTION = -0.01  # bad supplier id / field / price, or a no-progress offer
PENALTY_UNKNOWN_ACTION = -0.02  # unrecognised action type
PENALTY_ALL_REJECTED = -0.5     # every supplier rejected -- procurement failed


class ProcureEnvironment:
    """
    Core environment logic for one procurement episode.

    One instance per WebSocket session. All mutable state lives on self --
    there is no shared state between sessions: task suppliers are deep-copied
    on reset, and observations/state snapshots receive copies of internal
    data. The environment is reset explicitly via reset(), not at
    construction time, so the same instance can be reused across episodes
    (optionally switching task via reset(task_id)).

    Episode lifecycle:
      1. reset(task_id=None) -- loads task, initialises supplier pool, returns obs
      2. step(action)        -- processes one agent action, returns next obs + reward
      3. Repeat until obs.done == True (accept, all-rejected, or step budget exhausted)
    """

    def __init__(self, task_id: str = "task1_easy"):
        self.task_id = task_id
        self._task = TASKS[task_id]
        self._episode_id: Optional[str] = None
        self._suppliers: list[dict] = []
        self._rejected_ids: set[int] = set()
        self._accepted_id: Optional[int] = None
        # Info the agent has uncovered, keyed by str(supplier id) -> {field: value}.
        self._revealed: dict[str, dict] = {}
        # Current per-unit price on the table, keyed by str(supplier id).
        self._best_offers: dict[str, float] = {}
        self._step_count: int = 0
        self._cumulative_reward: float = 0.0
        self._done: bool = False
        # int supplier id -> whether a deceptive supplier has taken the bait.
        self._deceptive_trap_triggered: dict[int, bool] = {}
        # Firm suppliers that have already made their single counter-offer.
        self._firm_countered: set[int] = set()

    # ------------------------------------------------------------------ #
    #  Public API                                                          #
    # ------------------------------------------------------------------ #

    def reset(self, task_id: Optional[str] = None) -> "ProcureObservation":
        """
        Start a fresh episode and return the opening observation.

        Args:
            task_id: Optional task to switch to before resetting. When None
                (the default) the current task is reused.

        Raises:
            KeyError: If ``task_id`` is not a key of ``TASKS``.
        """
        if task_id is not None:
            task = TASKS[task_id]  # look up first so a bad id leaves state untouched
            self._task = task
            self.task_id = task_id

        self._episode_id = str(uuid.uuid4())[:8]
        # Deep copy so per-episode mutation never leaks into the shared TASKS table.
        self._suppliers = copy.deepcopy(self._task["suppliers"])
        self._rejected_ids = set()
        self._accepted_id = None
        self._revealed = {str(s["id"]): {} for s in self._suppliers}
        self._best_offers = {str(s["id"]): s["quoted_price"] for s in self._suppliers}
        self._step_count = 0
        self._cumulative_reward = 0.0
        self._done = False
        self._deceptive_trap_triggered = {s["id"]: False for s in self._suppliers}
        self._firm_countered = set()

        rfq = self._task["rfq"]
        supplier_names = ", ".join(s["name"] for s in self._suppliers)
        certs_note = (
            f" Required certifications: {rfq['required_certs']}."
            if rfq["required_certs"] else ""
        )
        msg = (
            f"RFQ: {rfq['item']} -- {rfq['quantity']} units, "
            f"budget ₹{rfq['budget']:,.0f}, deadline {rfq['deadline_days']} days.{certs_note} "
            f"Suppliers: {supplier_names}. "
            f"Step budget: {self._task['max_steps']} steps. "
            f"Verify certifications and quality before accepting."
        )
        return self._build_observation(reward=0.0, message=msg)

    def step(self, action: dict) -> "ProcureObservation":
        """
        Apply one agent action and return the resulting observation.

        ``action["action"]`` selects the handler (query, request_doc, offer,
        accept, reject); the remaining keys are handler-specific. The returned
        reward is per-step shaping, or the terminal score on accept(). Once the
        step budget is used up the episode ends with a TIMEOUT note.
        """
        if self._done:
            return self._build_observation(reward=0.0, message="Episode already done. Call reset().")

        self._step_count += 1
        action_type = action.get("action")

        if action_type == "query":
            reward, message = self._handle_query(action)
        elif action_type == "request_doc":
            reward, message = self._handle_request_doc(action)
        elif action_type == "offer":
            reward, message = self._handle_offer(action)
        elif action_type == "accept":
            reward, message = self._handle_accept(action)
        elif action_type == "reject":
            reward, message = self._handle_reject(action)
        else:
            message = f"Unknown action '{action_type}'. Valid: query, request_doc, offer, accept, reject."
            reward = PENALTY_UNKNOWN_ACTION

        steps_remaining = self._task["max_steps"] - self._step_count
        if steps_remaining <= 0 and not self._done:
            self._done = True
            message += " TIMEOUT: step budget exhausted without completing procurement."

        self._cumulative_reward += reward
        return self._build_observation(reward=reward, message=message)

    def state(self) -> "ProcureState":
        """
        Return the full episode state, including hidden supplier data.

        The supplier list is deep-copied so callers cannot mutate the live episode.
        """
        return ProcureState(
            task_id=self.task_id,
            episode_id=self._episode_id or "",
            step_count=self._step_count,
            done=self._done,
            cumulative_reward=self._cumulative_reward,
            accepted_supplier_id=self._accepted_id,
            suppliers_hidden=copy.deepcopy(self._suppliers),
        )

    # ------------------------------------------------------------------ #
    #  Action Handlers                                                     #
    # ------------------------------------------------------------------ #

    def _active_supplier_or_error(self, sid) -> tuple[Optional[dict], Optional[tuple[float, str]]]:
        """
        Resolve ``sid`` to a supplier that has not been rejected.

        Returns ``(supplier, None)`` on success, or ``(None, (penalty, message))``
        when the id is unknown or the supplier was already rejected.
        """
        supplier = self._get_supplier(sid)
        if supplier is None:
            return None, (PENALTY_INVALID_ACTION, f"Supplier ID {sid} not found.")
        if sid in self._rejected_ids:
            return None, (PENALTY_INVALID_ACTION, f"{supplier['name']} has already been rejected.")
        return supplier, None

    def _handle_query(self, action: dict) -> tuple[float, str]:
        """Reveal one queryable attribute (lead_time, moq, reliability) of a supplier."""
        sid = action.get("supplier_id")
        field = action.get("field")
        supplier, error = self._active_supplier_or_error(sid)
        if error:
            return error

        field_map = {
            "lead_time": ("lead_time_days", "lead time (days)"),
            "moq": ("moq", "minimum order quantity"),
            "reliability": ("reliability", "reliability score"),
        }
        if field not in field_map:
            return PENALTY_INVALID_ACTION, f"Unknown field '{field}'. Valid: lead_time, moq, reliability."

        internal_key, display = field_map[field]
        known = self._revealed[str(sid)]
        already_known = field in known
        value = supplier[internal_key]
        known[field] = value
        name = supplier["name"]

        if already_known:
            return 0.0, f"{name}: {display} = {value} (already on record)."

        # Add context when the value is decision-relevant.
        note = ""
        if field == "lead_time" and self._task["rfq"].get("deadline_days"):
            deadline = self._task["rfq"]["deadline_days"]
            if value > deadline:
                note = f" That exceeds your {deadline}-day deadline -- worth flagging."
        if field == "reliability" and isinstance(value, float) and value < RELIABILITY_CAUTION:
            note = " Reliability below 0.80 -- verify carefully before committing."

        return REWARD_QUERY_NEW, f"{name}: {display} = {value}.{note}"

    def _handle_request_doc(self, action: dict) -> tuple[float, str]:
        """
        Reveal a supplier document (quality_report, certifications, financial_stability).

        Discovering a disqualifying issue (quality below threshold, or missing
        required certs) earns REWARD_ISSUE_FOUND instead of REWARD_DOC_NEW.
        """
        sid = action.get("supplier_id")
        doc_type = action.get("doc_type")
        supplier, error = self._active_supplier_or_error(sid)
        if error:
            return error

        doc_map = {
            "quality_report": ("quality_score", "quality score"),
            "certifications": ("certifications", "certifications"),
            "financial_stability": ("reliability", "financial stability / reliability"),
        }
        if doc_type not in doc_map:
            return PENALTY_INVALID_ACTION, (
                f"Unknown doc_type '{doc_type}'. Valid: quality_report, certifications, financial_stability."
            )

        internal_key, display = doc_map[doc_type]
        known = self._revealed[str(sid)]
        already_known = doc_type in known
        value = supplier[internal_key]
        known[doc_type] = value
        name = supplier["name"]

        if already_known:
            return 0.0, f"{name} {display}: {value} (already on record)."

        extra_msg = ""
        reward = REWARD_DOC_NEW

        if doc_type == "quality_report" and isinstance(value, float):
            grade = (
                "Excellent" if value >= 0.85 else
                "Good" if value >= 0.70 else
                "Acceptable" if value >= QUALITY_THRESHOLD else
                "Below threshold"
            )
            extra_msg = f" ({grade})"
            if value < QUALITY_THRESHOLD:
                extra_msg += f" -- below the {QUALITY_THRESHOLD} floor. Recommend rejecting."
                reward = REWARD_ISSUE_FOUND

        if doc_type == "certifications":
            required = self._task["rfq"]["required_certs"]
            missing = [c for c in required if c not in value]
            if missing:
                extra_msg = f" WARNING: missing required certs: {missing}. This supplier cannot be accepted."
                reward = REWARD_ISSUE_FOUND

        return reward, f"{name} {display}: {value}.{extra_msg}"

    @staticmethod
    def _is_valid_price(price) -> bool:
        """True for a finite, strictly positive int/float (bools excluded)."""
        # NaN fails every comparison, so the chained bound rejects it as well as +/-inf.
        return (
            isinstance(price, (int, float))
            and not isinstance(price, bool)
            and 0 < price < float("inf")
        )

    def _handle_offer(self, action: dict) -> tuple[float, str]:
        """
        Negotiate a per-unit price with one supplier.

        Supplier behaviour:
          * deceptive -- "confirms" any lower price (the bait); the trap springs on accept().
          * any other -- an offer below ``min_price`` is countered at ``min_price``.
          * flexible  -- accepts any offer at or above ``min_price``.
          * firm      -- counters once, splitting the difference between the buyer's
                         offer and its current price, then holds at that price.

        Shaping reward is the price improvement as a fraction of the quoted price,
        scaled by OFFER_SHAPING_*. Missing/invalid prices and offers that do not
        undercut the current price cost PENALTY_INVALID_ACTION.
        """
        sid = action.get("supplier_id")
        offered_price = action.get("price")
        supplier, error = self._active_supplier_or_error(sid)
        if error:
            return error

        name = supplier["name"]
        if sid == self._accepted_id:
            return PENALTY_INVALID_ACTION, f"{name} is already accepted."
        # Validate before any comparison: None/str would raise TypeError, and a
        # deceptive supplier would otherwise "confirm" NaN, zero or negative prices.
        if not self._is_valid_price(offered_price):
            return PENALTY_INVALID_ACTION, (
                f"Invalid price {offered_price!r} for {name}. Offer a positive per-unit price."
            )

        key = str(sid)
        min_price = supplier["min_price"]
        behavior = supplier["behavior"]
        quoted_price = supplier["quoted_price"]
        current_price = self._best_offers[key]

        if offered_price >= current_price:
            return PENALTY_INVALID_ACTION, (
                f"{name} noted ₹{offered_price:,.0f}/unit but their current price is "
                f"₹{current_price:,.0f}. Offer below their current price to negotiate."
            )

        if behavior == "deceptive":
            # Deceptive supplier accepts anything during negotiation -- trap springs on accept().
            # Record that the bait was taken so accept() can fire the revision.
            self._best_offers[key] = offered_price
            self._deceptive_trap_triggered[sid] = True
            return REWARD_DECEPTIVE_BAIT, (
                f"{name} confirmed ₹{offered_price:,.0f}/unit. "
                f"Use accept to finalise -- verify their reliability score first."
            )

        if behavior == "firm" and sid in self._firm_countered:
            # A firm supplier makes exactly one concession, then holds.
            return PENALTY_INVALID_ACTION, (
                f"{name} is holding at ₹{current_price:,.0f}/unit -- that's their best offer."
            )

        if offered_price < min_price:
            # Offered below floor: supplier counters at their minimum.
            self._best_offers[key] = min_price
            if behavior == "firm":
                self._firm_countered.add(sid)
            improvement = (current_price - min_price) / quoted_price
            return improvement * OFFER_SHAPING_COUNTER, (
                f"{name} cannot go below ₹{min_price:,.0f}/unit -- that's their floor. "
                f"Countered at ₹{min_price:,.0f}."
            )

        if behavior == "flexible":
            self._best_offers[key] = offered_price
            improvement = (current_price - offered_price) / quoted_price
            return improvement * OFFER_SHAPING_ACCEPTED, (
                f"{name} accepted ₹{offered_price:,.0f}/unit. "
                f"Price locked in. Use accept to finalise."
            )

        if behavior == "firm":
            # Meet the buyer halfway between their offer and the current price.
            # The counter never drops below the buyer's own offer (which is already
            # >= min_price here), so a firm supplier never concedes more than asked.
            midpoint = round((offered_price + current_price) / 2, 0)
            counter = max(offered_price, midpoint)
            self._best_offers[key] = counter
            self._firm_countered.add(sid)
            improvement = (current_price - counter) / quoted_price
            return improvement * OFFER_SHAPING_COUNTER, (
                f"{name} came down to ₹{counter:,.0f}/unit -- that's their best offer."
            )

        return 0.0, f"Offer to {name} processed."

    def _handle_accept(self, action: dict) -> tuple[float, str]:
        """
        Close the deal with one supplier and return the terminal score as reward.

        If a deceptive supplier took the negotiation bait, its price is revised
        upward by DECEPTIVE_PRICE_REVISION and the score carries the deception penalty.
        """
        sid = action.get("supplier_id")
        supplier = self._get_supplier(sid)
        if supplier is None:
            return PENALTY_INVALID_ACTION, f"Supplier ID {sid} not found."
        if sid in self._rejected_ids:
            return PENALTY_INVALID_ACTION, f"Cannot accept {supplier['name']} -- already rejected."

        name = supplier["name"]
        key = str(sid)

        # Deception trap: a supplier that agreed during negotiation springs a price
        # revision on the formal accept call. The message below hints that low
        # reliability and missing certifications were the warning signs.
        if supplier["behavior"] == "deceptive" and self._deceptive_trap_triggered.get(sid):
            agreed = self._best_offers[key]
            revised = round(agreed * DECEPTIVE_PRICE_REVISION, 0)
            self._best_offers[key] = revised
            self._done = True
            self._accepted_id = sid
            final_score = self._compute_final_score(sid, was_deceived=True)
            # Note: step() adds the returned reward to cumulative_reward -- do not add here.
            return final_score, (
                f"DECEPTION: {name} has revised their price to ₹{revised:,.0f}/unit "
                f"(+{DECEPTIVE_PRICE_REVISION - 1:.0%} from the ₹{agreed:,.0f} agreed during negotiation), "
                f"citing 'raw material cost escalation.' "
                f"Deal locked in at revised price. "
                f"Final score: {final_score:.3f}. "
                f"Hint: low reliability score and missing certifications were warning signs."
            )

        self._accepted_id = sid
        self._done = True
        final_score = self._compute_final_score(sid, was_deceived=False)

        rfq = self._task["rfq"]
        price = self._best_offers[key]
        total_cost = price * rfq["quantity"]
        budget = rfq["budget"]
        missing_certs = [c for c in rfq["required_certs"] if c not in supplier["certifications"]]

        cert_msg = (
            f" WARNING: missing required certs {missing_certs} -- compliance penalty applied."
            if missing_certs else ""
        )
        budget_note = (
            f"within budget (saved ₹{budget - total_cost:,.0f})"
            if total_cost <= budget
            else f"OVER BUDGET by ₹{total_cost - budget:,.0f}"
        )

        return final_score, (
            f"Deal closed: {name} at ₹{price:,.0f}/unit. "
            f"Total: ₹{total_cost:,.0f} -- {budget_note}.{cert_msg} "
            f"Final score: {final_score:.3f}."
        )

    def _handle_reject(self, action: dict) -> tuple[float, str]:
        """Remove a supplier from consideration; rejecting the last one fails the episode."""
        sid = action.get("supplier_id")
        supplier = self._get_supplier(sid)
        if supplier is None:
            return PENALTY_INVALID_ACTION, f"Supplier ID {sid} not found."
        if sid in self._rejected_ids:
            return 0.0, f"{supplier['name']} was already rejected."

        self._rejected_ids.add(sid)
        name = supplier["name"]
        all_ids = {s["id"] for s in self._suppliers}

        if self._rejected_ids == all_ids:
            self._done = True
            return PENALTY_ALL_REJECTED, f"Rejected {name}. All suppliers eliminated -- procurement failed."

        remaining = len(all_ids) - len(self._rejected_ids)
        return 0.0, f"Rejected {name}. {remaining} supplier(s) still under consideration."

    # ------------------------------------------------------------------ #
    #  Grader / Final Score                                                #
    # ------------------------------------------------------------------ #

    def _compute_final_score(self, accepted_sid: int, was_deceived: bool) -> float:
        """
        Terminal reward emitted when the agent accepts a supplier.

        Weighted across four procurement dimensions:

          cost_efficiency (40%):
            Where the deal's total cost sits between the highest quoted price and
            the lowest min_price among valid suppliers (required certs + quality
            >= QUALITY_THRESHOLD). Highest quote = 0.0; lowest valid floor or
            cheaper = 0.40. The ratio is clamped to [0, 1], so a price below every
            valid floor (an uncertified supplier, or a deceptive one that accepted
            a lowball) cannot push this past its weight. Over-budget = 0.0. If no
            supplier is valid, NO_VALID_SUPPLIER_COST_CREDIT is granted.

          cert_compliance (30%):
            Fraction of required RFQ certifications held by the accepted supplier
            (e.g. with two required certs, one missing halves this). If no certs
            are required, full credit is granted automatically.

          quality_check (20%):
            Full credit only if quality_report was requested before accept() AND
            the supplier's quality_score >= QUALITY_THRESHOLD. Requesting the report
            and finding poor quality = QUALITY_PARTIAL_CREDIT. Skipping it = 0.0.

          due_diligence (10%):
            DILIGENCE_PER_FIELD for each of lead_time, moq, reliability and
            certifications revealed for this supplier, capped at 0.10.

        Deception penalty:
          If the agent was deceived, the total is multiplied by
          DECEPTION_PENALTY_MULTIPLIER.

        Returns a value clamped to [SCORE_MIN, SCORE_MAX] and rounded to 3 places --
        the OpenEnv validator rejects exactly 0.0 or 1.0.
        """
        supplier = self._get_supplier(accepted_sid)
        rfq = self._task["rfq"]
        revealed = self._revealed[str(accepted_sid)]
        required_certs = rfq["required_certs"]

        # 1. Cost efficiency (0.0 -- 0.40)
        total_cost = self._best_offers[str(accepted_sid)] * rfq["quantity"]
        if total_cost > rfq["budget"]:
            cost_score = 0.0
        else:
            valid_suppliers = [
                s for s in self._suppliers
                if all(c in s["certifications"] for c in required_certs)
                and s["quality_score"] >= QUALITY_THRESHOLD
            ]
            if valid_suppliers:
                best_floor = min(s["min_price"] for s in valid_suppliers) * rfq["quantity"]
                worst_quoted = max(s["quoted_price"] for s in valid_suppliers) * rfq["quantity"]
                spread = max(worst_quoted - best_floor, 1)
                ratio = (worst_quoted - total_cost) / spread
                cost_score = COST_EFFICIENCY_WEIGHT * min(1.0, max(0.0, ratio))
            else:
                cost_score = NO_VALID_SUPPLIER_COST_CREDIT

        # 2. Certification compliance (0.0 -- 0.30)
        if not required_certs:
            cert_score = CERT_COMPLIANCE_WEIGHT
        else:
            certs_met = sum(1 for c in required_certs if c in supplier["certifications"])
            cert_score = CERT_COMPLIANCE_WEIGHT * (certs_met / len(required_certs))

        # 3. Quality check (0.0 -- 0.20)
        if "quality_report" in revealed:
            quality_score = (
                QUALITY_CHECK_WEIGHT
                if supplier["quality_score"] >= QUALITY_THRESHOLD
                else QUALITY_PARTIAL_CREDIT
            )
        else:
            quality_score = 0.0

        # 4. Due diligence (0.0 -- 0.10)
        diligence_fields = {"lead_time", "moq", "reliability", "certifications"}
        checks = sum(1 for k in revealed if k in diligence_fields)
        diligence_score = min(DUE_DILIGENCE_WEIGHT, checks * DILIGENCE_PER_FIELD)

        total = cost_score + cert_score + quality_score + diligence_score
        if was_deceived:
            total *= DECEPTION_PENALTY_MULTIPLIER

        return round(min(SCORE_MAX, max(SCORE_MIN, total)), 3)

    # ------------------------------------------------------------------ #
    #  Helpers                                                             #
    # ------------------------------------------------------------------ #

    def _get_supplier(self, sid: int) -> Optional[dict]:
        """Return the episode's supplier dict whose id equals ``sid``, or None."""
        return next((s for s in self._suppliers if s["id"] == sid), None)

    def _build_observation(self, reward: float, message: str) -> "ProcureObservation":
        """
        Build the agent-facing observation.

        Only non-hidden supplier fields are exposed. Dicts are copied so the
        observation never aliases live episode state or the shared TASKS table.
        """
        visible = []
        for s in self._suppliers:
            if s["id"] in self._rejected_ids:
                status = "rejected"
            elif s["id"] == self._accepted_id:
                status = "accepted"
            else:
                status = "active"
            visible.append(SupplierVisible(
                id=s["id"],
                name=s["name"],
                quoted_price=s["quoted_price"],
                item_category=s["item_category"],
                status=status,
            ))

        return ProcureObservation(
            rfq=copy.deepcopy(self._task["rfq"]),
            suppliers=visible,
            revealed_info=copy.deepcopy(self._revealed),
            current_best_offers=dict(self._best_offers),
            step_count=self._step_count,
            steps_remaining=self._task["max_steps"] - self._step_count,
            done=self._done,
            reward=reward,
            cumulative_reward=self._cumulative_reward,
            accepted_supplier_id=self._accepted_id,
            message=self._enrich_message(message),
        )

    def _enrich_message(self, base_message: str) -> str:
        """
        Append decision-relevant context to every outbound message.

        Adds: steps remaining, unchecked certification warnings for active suppliers,
        quality report reminders, and a summary of revealed info. These hints flag
        *what to check*, not the hidden check results.

        Skipped when the episode is already done (no further actions possible).
        """
        steps_remaining = self._task["max_steps"] - self._step_count
        if self._done or steps_remaining <= 0:
            return base_message

        required_certs = self._task["rfq"]["required_certs"]
        parts = [base_message, f"Steps remaining: {steps_remaining}/{self._task['max_steps']}."]

        active = [
            s for s in self._suppliers
            if s["id"] not in self._rejected_ids and s["id"] != self._accepted_id
        ]

        # Flag active suppliers with unchecked certifications when certs are required.
        if required_certs:
            unchecked_certs = [
                s["name"] for s in active
                if "certifications" not in self._revealed.get(str(s["id"]), {})
            ]
            if unchecked_certs:
                parts.append(
                    f"Cert check pending for: {', '.join(unchecked_certs)}. "
                    f"Required: {required_certs}."
                )

        # Flag active suppliers without a quality report.
        no_quality = [
            s["name"] for s in active
            if "quality_report" not in self._revealed.get(str(s["id"]), {})
        ]
        if no_quality:
            parts.append(f"Quality report not yet requested for: {', '.join(no_quality)}.")

        # Summarise what's already known about every non-rejected supplier.
        known_summaries = []
        for s in self._suppliers:
            if s["id"] in self._rejected_ids:
                continue
            revealed = self._revealed.get(str(s["id"]), {})
            if revealed:
                kvs = ", ".join(f"{k}={v}" for k, v in revealed.items())
                known_summaries.append(f"{s['name']}: {kvs}")
        if known_summaries:
            parts.append(f"Known: {'; '.join(known_summaries)}.")

        return " | ".join(parts)