Spaces:
Sleeping
Sleeping
import random
from typing import Dict, List, Optional, Tuple

from netzero_nav.models import (
    Observation, Inventory, Shipment, Order,
    Action, ActionType, TransportMode, PartType
)
class NetZeroEnv:
    """Step-based supply-chain environment with cash and carbon accounting.

    An agent orders parts over one of several transport modes, waits for the
    shipments to arrive, produces goods to fulfil pending orders and may buy
    carbon offsets. Only the ``SKIP`` action advances the day counter; every
    other action is resolved immediately within the current day.
    """

    # Transport Mode Specs: (Cost multiplier, Speed/ETA, Carbon unit)
    TRANSPORT_SPECS = {
        TransportMode.SEA: (1.0, 10, 0.1),
        TransportMode.AIR: (5.0, 2, 2.0),
        TransportMode.RAIL: (2.5, 5, 0.5),
        TransportMode.ROAD: (1.5, 4, 0.8),
    }

    # Per-task settings: (carbon quota, day before which sea routes are blocked).
    # Kept separate from order generation so that building the order list has
    # no side effects on the running episode.
    TASK_SETTINGS = {
        "easy": (2000.0, 0),
        "medium": (2000.0, 15),
        "hard": (800.0, 20),
    }

    # The episode ends once this many days have elapsed.
    MAX_STEPS = 50

    step_count: int
    inventory: Inventory
    active_shipments: List[Shipment]
    pending_orders: List[Order]
    carbon_total: float
    cash_balance: float
    carbon_quota: float
    sea_blocked_until: int
    done: bool

    def __init__(self, task: str = "easy"):
        """Create the environment for ``task`` and reset it to day 0.

        Args:
            task: Difficulty name. "easy" and "medium" are matched exactly;
                any other value is treated as "hard".
        """
        self.task = task
        self.reset()

    def reset(self, seed: int = 42) -> Observation:
        """Start a new episode and return the initial observation.

        Args:
            seed: Seed applied to the global ``random`` module, which is the
                source of shipment ids.

        Returns:
            The observation for day 0.
        """
        random.seed(seed)
        # Unknown task names fall back to "hard", mirroring the final
        # ``else`` branch of _generate_initial_orders.
        quota, blocked_until = self.TASK_SETTINGS.get(
            self.task, self.TASK_SETTINGS["hard"]
        )
        self.step_count = 0
        self.inventory = Inventory()
        self.active_shipments = []
        self.pending_orders = self._generate_initial_orders()
        # Remember the starting demand so the final score does not have to
        # regenerate the order list.
        self._initial_order_total = sum(o.quantity for o in self.pending_orders)
        self.carbon_total = 0.0
        self.cash_balance = 10000.0
        self.carbon_quota = quota
        self.sea_blocked_until = blocked_until
        self.done = False
        return self._get_obs()

    def _generate_initial_orders(self) -> List[Order]:
        """Build a fresh list of the starting orders for the current task.

        This method only constructs orders; it does not modify the
        environment.
        """
        if self.task == "easy":
            return [
                Order(id="ORD_001", product="EcoPhone", quantity=5, due_date=20, reward=500.0),
                Order(id="ORD_002", product="GreenTab", quantity=3, due_date=30, reward=800.0),
            ]
        if self.task == "medium":
            return [
                Order(id="ORD_001", product="EcoPhone", quantity=8, due_date=15, reward=800.0),
                Order(id="ORD_002", product="GreenTab", quantity=5, due_date=25, reward=1200.0),
            ]
        # hard
        return [
            Order(id="ORD_001", product="EcoPhone", quantity=10, due_date=12, reward=1000.0),
            Order(id="ORD_002", product="GreenTab", quantity=10, due_date=20, reward=2000.0),
        ]

    def _get_obs(self, news: Optional[str] = None) -> Observation:
        """Package the current state into an Observation.

        The inventory and the shipment/order lists are passed by reference,
        not copied.
        """
        return Observation(
            step=self.step_count,
            inventory=self.inventory,
            active_shipments=self.active_shipments,
            pending_orders=self.pending_orders,
            carbon_total=self.carbon_total,
            carbon_quota=self.carbon_quota,
            cash_balance=self.cash_balance,
            news=news,
        )

    def step(self, action: Action) -> Tuple[Observation, float, bool, dict]:
        """Apply one action and return ``(observation, reward, done, info)``.

        ``SKIP`` advances the day; ``ORDER_PARTS``, ``PRODUCE``, ``OFFSET``
        and ``CANCEL`` are handled without advancing time. Any other action
        type is a no-op with zero reward.

        ``info`` may contain "news", "arrivals", "error" and, once the
        episode is done, "final_score".
        """
        reward = 0.0
        news = None
        info = {}

        if action.action_type == ActionType.SKIP:
            # Day advancement
            day_reward, news = self._advance_day(info)
            reward += day_reward
        elif action.action_type == ActionType.ORDER_PARTS:
            reward += self._handle_order_parts(action, info)
        elif action.action_type == ActionType.PRODUCE:
            reward += self._handle_production(action, info)
        elif action.action_type == ActionType.OFFSET:
            reward += self._handle_offset(action, info)
        elif action.action_type == ActionType.CANCEL:
            reward += self._handle_cancel(action, info)

        # Check termination
        if self.step_count >= self.MAX_STEPS or not self.pending_orders:
            self.done = True
            info["final_score"] = self._calculate_final_score()

        return self._get_obs(news=news), reward, self.done, info

    def _advance_day(self, info: dict) -> Tuple[float, Optional[str]]:
        """Move the clock forward one day.

        Lifts the sea blockade when due, delivers shipments whose ETA has
        elapsed and charges a penalty for every overdue pending order.

        Returns:
            ``(reward, news)`` where ``reward`` is zero or negative and
            ``news`` is a message for the observation, or None.
        """
        news = None
        self.step_count += 1

        # Sea orders are accepted as soon as step_count reaches
        # sea_blocked_until (see _handle_order_parts), so announce the
        # reopening on that same day.
        if self.sea_blocked_until > 0 and self.step_count >= self.sea_blocked_until:
            news = "Suez route is clear again."
            info["news"] = news
            self.sea_blocked_until = 0

        # Process active shipments
        arrivals = []
        in_transit = []
        for shipment in self.active_shipments:
            shipment.eta -= 1
            if shipment.eta <= 0:
                arrivals.append(f"{shipment.quantity}x {shipment.part.value}")
                self._receive_shipment(shipment)
            else:
                in_transit.append(shipment)
        self.active_shipments = in_transit
        if arrivals:
            info["arrivals"] = arrivals

        # Check order deadlines: each overdue order costs 50 per elapsed day.
        # NOTE(review): the source this was recovered from had lost its
        # indentation; the penalty is assumed to apply only when a day
        # passes, not on every action - confirm.
        reward = 0.0
        for order in self.pending_orders:
            if self.step_count > order.due_date:
                reward -= 50.0  # Late penalty
        return reward, news

    def _new_shipment_id(self) -> str:
        """Draw a shipment id that no active shipment is currently using.

        Ids come from a range of 9000 values; the loop assumes far fewer
        shipments than that are in transit at once.
        """
        taken = {shipment.id for shipment in self.active_shipments}
        while True:
            candidate = f"SHP_{random.randint(1000, 9999)}"
            if candidate not in taken:
                return candidate

    def _handle_order_parts(self, action: Action, info: dict) -> float:
        """Buy parts and put them in transit.

        The purchase is charged and its carbon is booked immediately. An
        order with the same part and mode placed on the same day as an
        existing shipment is merged into that shipment.

        Returns:
            2.0 on success; -5.0 for a malformed action, -15.0 when sea
            routes are blocked, -10.0 when cash is insufficient.
        """
        if not action.part_type or not action.mode or not action.quantity:
            return -5.0
        if action.quantity < 0:
            # A negative quantity would credit cash and remove carbon.
            info["error"] = "Quantity must be positive"
            return -5.0

        specs = self.TRANSPORT_SPECS.get(action.mode)
        if specs is None:
            info["error"] = "Unknown transport mode"
            return -5.0
        mult, eta, carbon = specs

        # Check Disruption
        if action.mode == TransportMode.SEA and self.step_count < self.sea_blocked_until:
            info["error"] = "Suez Blocked: Sea routes unavailable"
            return -15.0

        base_cost = 10.0 * action.quantity
        total_cost = base_cost * mult
        if self.cash_balance < total_cost:
            info["error"] = "Insufficient funds"
            return -10.0

        carbon_impact = carbon * action.quantity
        self.cash_balance -= total_cost
        self.carbon_total += carbon_impact

        # A shipment whose ETA still equals the mode's full ETA was created
        # today, so the new order can ride along with it.
        for ship in self.active_shipments:
            if ship.part == action.part_type and ship.mode == action.mode and ship.eta == eta:
                ship.quantity += action.quantity
                ship.cost += total_cost
                ship.carbon_impact += carbon_impact
                return 2.0

        self.active_shipments.append(
            Shipment(
                id=self._new_shipment_id(),
                part=action.part_type,
                quantity=action.quantity,
                mode=action.mode,
                eta=eta,
                carbon_impact=carbon_impact,
                cost=total_cost,
            )
        )
        return 2.0

    def _handle_cancel(self, action: Action, info: dict) -> float:
        """Cancel an in-transit shipment, refunding its cost and carbon.

        Always returns 0.0; an unknown shipment id is reported through
        ``info["error"]``.
        """
        if not action.shipment_id:
            return 0.0
        for index, ship in enumerate(self.active_shipments):
            if ship.id == action.shipment_id:
                # Refund
                self.cash_balance += ship.cost
                self.carbon_total = max(0.0, self.carbon_total - ship.carbon_impact)
                self.active_shipments.pop(index)
                return 0.0
        info["error"] = "Shipment not found"
        return 0.0

    def _receive_shipment(self, ship: Shipment):
        """Add a delivered shipment to the inventory.

        The part's enum value is used as the name of the inventory attribute
        to increase.
        """
        current_val = getattr(self.inventory, ship.part.value)
        setattr(self.inventory, ship.part.value, current_val + ship.quantity)

    def _handle_production(self, action: Action, info: dict) -> float:
        """Turn parts into products and fulfil matching pending orders.

        Every unit needs one chip; "EcoPhone" additionally needs one sensor
        and "GreenTab" one battery. A missing or zero quantity means one unit.

        Returns:
            10.0 per unit produced plus the reward of every order completed
            by this run; -5.0 for a malformed action, -10.0 when parts are
            missing.
        """
        if not action.product:
            return -5.0
        qty = action.quantity if action.quantity else 1
        if qty < 0:
            # A negative run would add parts to the inventory.
            info["error"] = "Quantity must be positive"
            return -5.0

        # Determine part requirements depending on product. To keep simulation clean, both use chips.
        req_chips = qty
        req_sensors = qty if action.product == "EcoPhone" else 0
        req_batteries = qty if action.product == "GreenTab" else 0

        has_parts = (
            self.inventory.chips >= req_chips
            and self.inventory.sensors >= req_sensors
            and self.inventory.batteries >= req_batteries
        )
        if not has_parts:
            info["error"] = "Missing parts for run"
            return -10.0

        self.inventory.chips -= req_chips
        self.inventory.sensors -= req_sensors
        self.inventory.batteries -= req_batteries

        total_reward = 10.0 * qty
        remaining_produce = qty
        completed = []
        # Apply the produced units to matching orders in list order.
        for order in self.pending_orders:
            if order.product == action.product and remaining_produce > 0:
                fulfilled = min(order.quantity, remaining_produce)
                order.quantity -= fulfilled
                remaining_produce -= fulfilled
                if order.quantity <= 0:
                    completed.append(order)
                    total_reward += order.reward
                    self.cash_balance += order.reward
        # Remove after the loop so the list is not mutated while iterating.
        for order in completed:
            self.pending_orders.remove(order)
        return total_reward

    def _handle_offset(self, action: Action, info: dict) -> float:
        """Buy carbon offsets at 2.0 cash per unit of carbon.

        The full requested amount is charged even if it exceeds the current
        carbon total, which is floored at zero.

        Returns:
            5.0 on success; -5.0 for a malformed action or when there is
            nothing to offset, -10.0 when cash is insufficient.
        """
        if not action.offset_amount:
            return -5.0
        if action.offset_amount < 0:
            # A negative amount would credit cash and add carbon.
            info["error"] = "Offset amount must be positive"
            return -5.0
        if self.carbon_total <= 0:
            info["error"] = "No carbon footprint to offset"
            return -5.0

        cost = action.offset_amount * 2.0
        if self.cash_balance < cost:
            info["error"] = "Insufficient funds for offset"
            return -10.0

        self.cash_balance -= cost
        self.carbon_total = max(0.0, self.carbon_total - action.offset_amount)
        return 5.0

    def _handle_reroute(self, action: Action, info: dict) -> float:
        """Placeholder for rerouting; does nothing and is not called by step()."""
        return 0.0

    def _calculate_final_score(self) -> float:
        """Return the episode score, clamped to the range [0.01, 0.99].

        The score is the fraction of the initially ordered units that were
        fulfilled, reduced by half of the relative carbon-quota overage.
        """
        # Score is primarily based on order fulfillment, penalized by carbon overages
        total_orders = self._initial_order_total
        remaining = sum(o.quantity for o in self.pending_orders)
        fulfilled_ratio = (total_orders - remaining) / total_orders if total_orders > 0 else 1.0
        score = fulfilled_ratio

        # Penalize if carbon quota exceeded
        if self.carbon_total > self.carbon_quota:
            overage = self.carbon_total - self.carbon_quota
            penalty = (overage / self.carbon_quota) * 0.5
            score = max(0.0, score - penalty)

        return float(max(0.01, min(0.99, score)))