# Aryanshh
# Compliance: Force strict (0, 1) score range in both env.py and inference.py logs
# d57c77b
import random
from typing import Dict, List, Optional, Tuple

from netzero_nav.models import (
    Observation, Inventory, Shipment, Order,
    Action, ActionType, TransportMode, PartType
)
class NetZeroEnv:
    """Supply-chain simulation with a cash balance and a carbon quota.

    The agent orders parts over several transport modes, produces goods to
    fulfil pending orders, buys carbon offsets and cancels shipments.  Only
    ``ActionType.SKIP`` advances the simulated day; every other action is
    resolved immediately on the current day.

    An episode is flagged ``done`` once ``MAX_STEPS`` days have elapsed or no
    order is pending.  The ``info`` dict returned by :meth:`step` then carries
    a ``"final_score"`` clamped to ``[0.01, 0.99]``.
    """

    # Transport Mode Specs: (Cost multiplier, Speed/ETA, Carbon unit)
    TRANSPORT_SPECS = {
        TransportMode.SEA: (1.0, 10, 0.1),
        TransportMode.AIR: (5.0, 2, 2.0),
        TransportMode.RAIL: (2.5, 5, 0.5),
        TransportMode.ROAD: (1.5, 4, 0.8),
    }

    # Scenario settings per task: (day until which sea routes are blocked,
    # carbon quota).  Any task name other than "easy"/"medium" is treated as
    # "hard", mirroring the fallback branch of _generate_initial_orders.
    # The hard quota is 800: that is the value the final score was always
    # computed against, so observations now report the same number.
    TASK_SETTINGS = {
        "easy": (0, 2000.0),
        "medium": (15, 2000.0),
        "hard": (20, 800.0),
    }

    MAX_STEPS = 50             # episode length cap, in simulated days
    STARTING_CASH = 10000.0
    BASE_PART_COST = 10.0      # per unit, before the transport multiplier
    OFFSET_PRICE = 2.0         # cash per unit of carbon offset
    PRODUCTION_REWARD = 10.0   # reward per produced unit
    LATE_PENALTY = 50.0        # per overdue order, per elapsed day

    task: str
    step_count: int
    inventory: Inventory
    active_shipments: List[Shipment]
    pending_orders: List[Order]
    carbon_total: float
    cash_balance: float
    carbon_quota: float
    sea_blocked_until: int
    done: bool

    def __init__(self, task: str = "easy"):
        """Create the environment for ``task`` and reset it immediately."""
        self.task = task
        self.reset()

    def reset(self, seed: int = 42) -> Observation:
        """Restore the initial state for the configured task.

        Seeds the module-level ``random`` generator with ``seed`` and returns
        the first observation.
        """
        random.seed(seed)
        self.step_count = 0
        self.inventory = Inventory()
        self.active_shipments = []
        self.pending_orders = self._generate_initial_orders()
        self.carbon_total = 0.0
        self.cash_balance = self.STARTING_CASH
        # Scenario settings are applied in one place so that nothing else in
        # reset() can silently overwrite the blockade or the quota.
        self.sea_blocked_until, self.carbon_quota = self.TASK_SETTINGS.get(
            self.task, self.TASK_SETTINGS["hard"]
        )
        self.done = False
        return self._get_obs()

    def _generate_initial_orders(self) -> List[Order]:
        """Return fresh copies of the task's starting orders.

        This method has no side effects, so it is safe to call again when
        computing the final score.
        """
        if self.task == "easy":
            return [
                Order(id="ORD_001", product="EcoPhone", quantity=5, due_date=20, reward=500.0),
                Order(id="ORD_002", product="GreenTab", quantity=3, due_date=30, reward=800.0),
            ]
        if self.task == "medium":
            return [
                Order(id="ORD_001", product="EcoPhone", quantity=8, due_date=15, reward=800.0),
                Order(id="ORD_002", product="GreenTab", quantity=5, due_date=25, reward=1200.0),
            ]
        # hard (also the fallback for unrecognised task names)
        return [
            Order(id="ORD_001", product="EcoPhone", quantity=10, due_date=12, reward=1000.0),
            Order(id="ORD_002", product="GreenTab", quantity=10, due_date=20, reward=2000.0),
        ]

    def _get_obs(self, news: Optional[str] = None) -> Observation:
        """Build an observation of the current state, optionally with news."""
        return Observation(
            step=self.step_count,
            inventory=self.inventory,
            active_shipments=self.active_shipments,
            pending_orders=self.pending_orders,
            carbon_total=self.carbon_total,
            carbon_quota=self.carbon_quota,
            cash_balance=self.cash_balance,
            news=news,
        )

    def step(self, action: Action) -> Tuple[Observation, float, bool, dict]:
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``info`` may contain ``"news"``, ``"arrivals"``, ``"error"`` and, once
        the episode is over, ``"final_score"``.
        """
        reward = 0.0
        news = None
        info: dict = {}

        action_type = action.action_type
        if action_type == ActionType.SKIP:
            # Day advancement
            day_reward, news = self._advance_day(info)
            reward += day_reward
        elif action_type == ActionType.ORDER_PARTS:
            reward += self._handle_order_parts(action, info)
        elif action_type == ActionType.PRODUCE:
            reward += self._handle_production(action, info)
        elif action_type == ActionType.OFFSET:
            reward += self._handle_offset(action, info)
        elif action_type == ActionType.CANCEL:
            reward += self._handle_cancel(action, info)

        # Check termination
        if self.step_count >= self.MAX_STEPS or not self.pending_orders:
            self.done = True
            info["final_score"] = self._calculate_final_score()

        return self._get_obs(news=news), reward, self.done, info

    def _advance_day(self, info: dict) -> Tuple[float, Optional[str]]:
        """Advance the clock one day and settle everything time-dependent.

        Returns the reward contribution of the day (late penalties) and the
        news headline for the observation, if any.
        """
        self.step_count += 1

        # Announce the reopening on the first day sea orders are accepted
        # again (_handle_order_parts blocks while step_count < blocked_until).
        news = None
        if 0 < self.sea_blocked_until <= self.step_count:
            news = "Suez route is clear again."
            info["news"] = news
            self.sea_blocked_until = 0

        # Process active shipments
        arrivals = []
        in_transit = []
        for shipment in self.active_shipments:
            shipment.eta -= 1
            if shipment.eta <= 0:
                arrivals.append(f"{shipment.quantity}x {shipment.part.value}")
                self._receive_shipment(shipment)
            else:
                in_transit.append(shipment)
        self.active_shipments = in_transit
        if arrivals:
            info["arrivals"] = arrivals

        # Check order deadlines.
        # NOTE(review): the penalty is charged once per elapsed day, not once
        # per action - confirm this matches the intended reward shaping.
        overdue = sum(
            1 for order in self.pending_orders if self.step_count > order.due_date
        )
        return -self.LATE_PENALTY * overdue, news

    def _new_shipment_id(self) -> str:
        """Draw a shipment id that no active shipment is currently using."""
        taken = {shipment.id for shipment in self.active_shipments}
        while True:
            candidate = f"SHP_{random.randint(1000, 9999)}"
            if candidate not in taken:
                return candidate

    def _handle_order_parts(self, action: Action, info: dict) -> float:
        """Buy parts and put them in transit.

        Returns 2.0 on success, -5.0 for an invalid request, -15.0 when the
        sea route is blocked and -10.0 when cash is insufficient.
        """
        specs = self.TRANSPORT_SPECS.get(action.mode) if action.mode else None
        # A negative quantity would otherwise yield a negative cost and
        # carbon impact, i.e. free cash and free offsets.
        if (
            not action.part_type
            or specs is None
            or not action.quantity
            or action.quantity < 0
        ):
            info["error"] = "Invalid part order"
            return -5.0

        mult, eta, carbon = specs

        # Check disruption
        if action.mode == TransportMode.SEA and self.step_count < self.sea_blocked_until:
            info["error"] = "Suez Blocked: Sea routes unavailable"
            return -15.0

        total_cost = self.BASE_PART_COST * action.quantity * mult
        if self.cash_balance < total_cost:
            info["error"] = "Insufficient funds"
            return -10.0

        carbon_impact = carbon * action.quantity
        self.cash_balance -= total_cost
        self.carbon_total += carbon_impact

        # Fold the order into a shipment placed earlier the same day with the
        # same part and mode (its eta still equals the mode's full eta).
        for ship in self.active_shipments:
            if ship.part == action.part_type and ship.mode == action.mode and ship.eta == eta:
                ship.quantity += action.quantity
                ship.cost += total_cost
                ship.carbon_impact += carbon_impact
                return 2.0

        self.active_shipments.append(
            Shipment(
                id=self._new_shipment_id(),
                part=action.part_type,
                quantity=action.quantity,
                mode=action.mode,
                eta=eta,
                carbon_impact=carbon_impact,
                cost=total_cost,
            )
        )
        return 2.0

    def _handle_cancel(self, action: Action, info: dict) -> float:
        """Cancel an in-transit shipment, refunding its cost and carbon.

        Always returns 0.0; an unknown id is reported via ``info["error"]``.
        """
        if not action.shipment_id:
            return 0.0
        for index, ship in enumerate(self.active_shipments):
            if ship.id == action.shipment_id:
                # Refund
                self.cash_balance += ship.cost
                self.carbon_total = max(0.0, self.carbon_total - ship.carbon_impact)
                self.active_shipments.pop(index)
                return 0.0
        info["error"] = "Unknown shipment id"
        return 0.0

    def _receive_shipment(self, ship: Shipment) -> None:
        """Add an arrived shipment's quantity to the matching inventory field."""
        field = ship.part.value
        setattr(self.inventory, field, getattr(self.inventory, field) + ship.quantity)

    def _handle_production(self, action: Action, info: dict) -> float:
        """Consume parts to build ``action.product`` and fulfil matching orders.

        A missing or zero quantity means one unit.  Returns the production
        reward plus the reward of every order completed by this run, -5.0 for
        an invalid request, or -10.0 when parts are missing.
        """
        if not action.product:
            info["error"] = "No product specified"
            return -5.0
        qty = action.quantity if action.quantity else 1
        if qty < 0:
            # A negative run would pass the stock check and mint parts.
            info["error"] = "Production quantity must be positive"
            return -5.0

        # Determine part requirements depending on product. To keep simulation clean, both use chips.
        req_chips = qty
        req_sensors = qty if action.product == "EcoPhone" else 0
        req_batteries = qty if action.product == "GreenTab" else 0

        stock = self.inventory
        if (
            stock.chips < req_chips
            or stock.sensors < req_sensors
            or stock.batteries < req_batteries
        ):
            info["error"] = "Missing parts for run"
            return -10.0

        stock.chips -= req_chips
        stock.sensors -= req_sensors
        stock.batteries -= req_batteries

        total_reward = self.PRODUCTION_REWARD * qty
        remaining = qty
        completed = []
        for order in self.pending_orders:
            if remaining <= 0:
                break
            if order.product != action.product:
                continue
            fulfilled = min(order.quantity, remaining)
            order.quantity -= fulfilled
            remaining -= fulfilled
            if order.quantity <= 0:
                completed.append(order)
                total_reward += order.reward
                self.cash_balance += order.reward
        # Remove after the loop so the list is not mutated while iterating.
        for order in completed:
            self.pending_orders.remove(order)
        return total_reward

    def _handle_offset(self, action: Action, info: dict) -> float:
        """Buy carbon offsets at ``OFFSET_PRICE`` cash per unit.

        Returns 5.0 on success, -5.0 for an invalid amount or when there is
        nothing to offset, and -10.0 when cash is insufficient.
        """
        if not action.offset_amount or action.offset_amount < 0:
            # A negative amount would otherwise pay the agent to add carbon.
            info["error"] = "Offset amount must be positive"
            return -5.0
        if self.carbon_total <= 0:
            info["error"] = "No carbon footprint to offset"
            return -5.0
        cost = action.offset_amount * self.OFFSET_PRICE
        if self.cash_balance < cost:
            info["error"] = "Insufficient funds for offset"
            return -10.0
        self.cash_balance -= cost
        self.carbon_total = max(0.0, self.carbon_total - action.offset_amount)
        return 5.0

    def _handle_reroute(self, action: Action, info: dict) -> float:
        """Placeholder for rerouting; not dispatched by step() and a no-op."""
        return 0.0

    def _calculate_final_score(self) -> float:
        """Return the episode score, clamped to ``[0.01, 0.99]``.

        The score is the fraction of initially ordered units that were
        fulfilled, minus half of the relative carbon-quota overage.
        """
        total_units = sum(o.quantity for o in self._generate_initial_orders())
        remaining_units = sum(o.quantity for o in self.pending_orders)
        score = (total_units - remaining_units) / total_units if total_units > 0 else 1.0

        # Penalize if carbon quota exceeded
        if self.carbon_total > self.carbon_quota:
            overage = self.carbon_total - self.carbon_quota
            score -= (overage / self.carbon_quota) * 0.5

        # Strict (0, 1) range: never report exactly 0 or 1.
        return float(max(0.01, min(0.99, score)))