File size: 9,798 Bytes
e3e3da2
 
0dd5d30
e3e3da2
 
 
 
6b8c880
e3e3da2
 
 
 
 
 
 
 
a95dc70
 
 
 
 
 
 
 
 
 
e3e3da2
 
 
 
 
 
 
 
 
 
 
 
 
2ddf5f5
e3e3da2
 
 
 
a95dc70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3e3da2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ddf5f5
e3e3da2
 
87e9783
c547cb1
 
87e9783
 
 
c547cb1
87e9783
 
 
 
 
 
 
 
 
 
 
 
 
c547cb1
 
 
 
 
 
 
e3e3da2
 
 
 
 
 
c547cb1
 
e3e3da2
 
 
 
 
 
c547cb1
e3e3da2
 
 
 
 
 
 
2ddf5f5
 
 
 
 
 
e3e3da2
 
 
 
 
 
 
 
 
c547cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3e3da2
 
c547cb1
 
 
 
 
 
 
 
 
 
 
 
e3e3da2
 
 
 
 
 
eaea692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3e3da2
 
 
51cbaeb
 
 
 
e3e3da2
 
 
 
 
51cbaeb
 
 
e3e3da2
 
 
 
 
a95dc70
 
 
 
 
 
 
 
 
 
 
 
 
d57c77b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import random
from typing import Dict, List, Optional, Tuple

from netzero_nav.models import (
    Observation, Inventory, Shipment, Order,
    Action, ActionType, TransportMode, PartType
)

class NetZeroEnv:
    """Day-stepped supply-chain simulator trading order fulfilment against carbon.

    Each :meth:`step` applies one :class:`Action`. Only ``SKIP`` advances the
    clock (moving shipments and applying late penalties); every other action
    resolves instantly on the current day. The episode ends after 50 days or
    once no orders remain pending.
    """

    # Transport Mode Specs: (cost multiplier on 10.0/unit, ETA in days, carbon per unit)
    TRANSPORT_SPECS = {
        TransportMode.SEA:  (1.0, 10, 0.1),
        TransportMode.AIR:  (5.0, 2,  2.0),
        TransportMode.RAIL: (2.5, 5,  0.5),
        TransportMode.ROAD: (1.5, 4,  0.8)
    }

    step_count: int
    inventory: Inventory
    active_shipments: List[Shipment]
    pending_orders: List[Order]
    carbon_total: float
    cash_balance: float
    carbon_quota: float
    sea_blocked_until: int
    done: bool

    def __init__(self, task: str = "easy"):
        """Create an environment for ``task`` and immediately reset it.

        ``task`` is "easy" or "medium"; any other value is treated as "hard".
        """
        self.task = task
        self.reset()

    def reset(self, seed: int = 42) -> Observation:
        """Start a new episode and return the initial observation.

        Seeds the module-level ``random`` generator (used for shipment IDs) so
        episodes are reproducible for a given ``seed``.
        """
        random.seed(seed)
        self.step_count = 0
        self.inventory = Inventory()
        self.active_shipments = []
        self.pending_orders = self._generate_initial_orders()
        self.carbon_total = 0.0
        self.cash_balance = 10000.0
        # Applied after the other defaults so nothing here can clobber the
        # task's disruption window or carbon quota.
        self._apply_task_scenario()
        self.done = False
        return self._get_obs()

    def _apply_task_scenario(self) -> None:
        """Set the Suez blockade window and carbon quota for ``self.task``."""
        if self.task == "easy":
            self.sea_blocked_until = 0
            self.carbon_quota = 2000.0
        elif self.task == "medium":
            self.sea_blocked_until = 15
            self.carbon_quota = 2000.0
        else:  # hard (also any unrecognised task, matching _generate_initial_orders)
            self.sea_blocked_until = 20
            self.carbon_quota = 800.0

    def _generate_initial_orders(self) -> List[Order]:
        """Return fresh customer orders for ``self.task``.

        Pure: it touches no environment state, so it is also safe to call from
        :meth:`_calculate_final_score` to recover the original demand.
        """
        if self.task == "easy":
            return [
                Order(id="ORD_001", product="EcoPhone", quantity=5, due_date=20, reward=500.0),
                Order(id="ORD_002", product="GreenTab", quantity=3, due_date=30, reward=800.0)
            ]
        if self.task == "medium":
            return [
                Order(id="ORD_001", product="EcoPhone", quantity=8, due_date=15, reward=800.0),
                Order(id="ORD_002", product="GreenTab", quantity=5, due_date=25, reward=1200.0)
            ]
        # hard
        return [
            Order(id="ORD_001", product="EcoPhone", quantity=10, due_date=12, reward=1000.0),
            Order(id="ORD_002", product="GreenTab", quantity=10, due_date=20, reward=2000.0)
        ]

    def _get_obs(self, news: Optional[str] = None) -> Observation:
        """Build an Observation of the current state, with an optional ``news`` headline.

        NOTE(review): the live inventory and list objects are passed, not
        copies - confirm Observation copies them if callers mutate the result.
        """
        return Observation(
            step=self.step_count,
            inventory=self.inventory,
            active_shipments=self.active_shipments,
            pending_orders=self.pending_orders,
            carbon_total=self.carbon_total,
            carbon_quota=self.carbon_quota,
            cash_balance=self.cash_balance,
            news=news
        )

    def step(self, action: Action) -> Tuple[Observation, float, bool, dict]:
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``info`` may contain "news", "arrivals", "error" and, once the episode
        has ended, "final_score".
        """
        reward = 0.0
        news: Optional[str] = None
        info: dict = {}

        # Only SKIP moves time forward; everything else happens "today".
        if action.action_type == ActionType.SKIP:
            reward += self._advance_day(info)
            news = info.get("news")

        handlers = {
            ActionType.ORDER_PARTS: self._handle_order_parts,
            ActionType.PRODUCE: self._handle_production,
            ActionType.OFFSET: self._handle_offset,
            ActionType.CANCEL: self._handle_cancel,
        }
        handler = handlers.get(action.action_type)
        if handler is not None:
            reward += handler(action, info)

        if self.step_count >= 50 or not self.pending_orders:
            self.done = True
            info["final_score"] = self._calculate_final_score()

        return self._get_obs(news=news), reward, self.done, info

    def _advance_day(self, info: dict) -> float:
        """Advance one day: lift the blockade, move shipments, apply late penalties.

        Writes "news"/"arrivals" into ``info`` and returns the (non-positive)
        late-order penalty for the day.
        """
        self.step_count += 1
        # Sea orders are accepted once step_count >= sea_blocked_until (see
        # _handle_order_parts), so announce the reopening on that same day.
        if self.sea_blocked_until > 0 and self.step_count >= self.sea_blocked_until:
            info["news"] = "Suez route is clear again."
            self.sea_blocked_until = 0

        arrivals = []
        in_transit = []
        for ship in self.active_shipments:
            ship.eta -= 1
            if ship.eta <= 0:
                arrivals.append(f"{ship.quantity}x {ship.part.value}")
                self._receive_shipment(ship)
            else:
                in_transit.append(ship)
        self.active_shipments = in_transit
        if arrivals:
            info["arrivals"] = arrivals

        # Each pending order past its due date costs 50 per elapsed day.
        late_orders = sum(1 for order in self.pending_orders if self.step_count > order.due_date)
        return -50.0 * late_orders

    def _handle_order_parts(self, action: Action, info: dict) -> float:
        """Buy ``action.quantity`` parts shipped via ``action.mode``.

        Returns -5 for an incomplete or negative request, -15 while sea freight
        is blockaded, -10 if cash is insufficient; otherwise charges cash,
        books carbon, adds (or merges into) an in-transit shipment and returns 2.
        """
        if not action.part_type or not action.mode or not action.quantity:
            return -5.0
        if action.quantity < 0:
            # A negative order would refund cash and erase carbon.
            info["error"] = "Quantity must be positive"
            return -5.0

        mult, eta, carbon_per_unit = self.TRANSPORT_SPECS[action.mode]

        if action.mode == TransportMode.SEA and self.step_count < self.sea_blocked_until:
            info["error"] = "Suez Blocked: Sea routes unavailable"
            return -15.0

        total_cost = 10.0 * action.quantity * mult
        if self.cash_balance < total_cost:
            info["error"] = "Insufficient funds"
            return -10.0

        carbon_impact = carbon_per_unit * action.quantity
        self.cash_balance -= total_cost
        self.carbon_total += carbon_impact

        # A shipment still at the full ETA for this mode was booked today with
        # the same part/mode, so fold this order into it.
        for ship in self.active_shipments:
            if ship.part == action.part_type and ship.mode == action.mode and ship.eta == eta:
                ship.quantity += action.quantity
                ship.cost += total_cost
                ship.carbon_impact += carbon_impact
                return 2.0

        # Re-draw on collision so CANCEL by id stays unambiguous.
        existing_ids = {ship.id for ship in self.active_shipments}
        shipment_id = f"SHP_{random.randint(1000, 9999)}"
        while shipment_id in existing_ids:
            shipment_id = f"SHP_{random.randint(1000, 9999)}"

        self.active_shipments.append(Shipment(
            id=shipment_id,
            part=action.part_type,
            quantity=action.quantity,
            mode=action.mode,
            eta=eta,
            carbon_impact=carbon_impact,
            cost=total_cost
        ))
        return 2.0

    def _handle_cancel(self, action: Action, info: dict) -> float:
        """Cancel in-transit shipment ``action.shipment_id`` with a full refund.

        Returns the shipment's cost to cash and removes its carbon (floored at
        0). Always yields 0 reward; sets ``info["error"]`` if the id is unknown.
        """
        if not action.shipment_id:
            return 0.0

        for index, ship in enumerate(self.active_shipments):
            if ship.id == action.shipment_id:
                self.cash_balance += ship.cost
                self.carbon_total = max(0.0, self.carbon_total - ship.carbon_impact)
                del self.active_shipments[index]
                return 0.0

        info["error"] = "Shipment not found"
        return 0.0

    def _receive_shipment(self, ship: Shipment) -> None:
        """Add an arrived shipment to the inventory field named by ``ship.part.value``."""
        current_val = getattr(self.inventory, ship.part.value)
        setattr(self.inventory, ship.part.value, current_val + ship.quantity)

    def _handle_production(self, action: Action, info: dict) -> float:
        """Produce ``action.quantity`` (default 1) units of ``action.product``.

        Every unit uses one chip; EcoPhone also needs a sensor and GreenTab a
        battery. Output fills pending orders for that product in list order;
        a fully filled order is removed and its reward is added to both the
        step reward and cash. Returns -5 for a missing product or negative
        quantity, -10 if parts are short.
        """
        if not action.product:
            return -5.0
        qty = action.quantity if action.quantity else 1
        if qty < 0:
            # A negative run would create parts and inflate outstanding orders.
            info["error"] = "Quantity must be positive"
            return -5.0

        req_chips = qty
        req_sensors = qty if action.product == "EcoPhone" else 0
        req_batteries = qty if action.product == "GreenTab" else 0

        if not (self.inventory.chips >= req_chips
                and self.inventory.sensors >= req_sensors
                and self.inventory.batteries >= req_batteries):
            info["error"] = "Missing parts for run"
            return -10.0

        self.inventory.chips -= req_chips
        self.inventory.sensors -= req_sensors
        self.inventory.batteries -= req_batteries

        total_reward = 10.0 * qty
        remaining_produce = qty

        orders_to_remove = []
        for o in self.pending_orders:
            if o.product == action.product and remaining_produce > 0:
                fulfilled = min(o.quantity, remaining_produce)
                o.quantity -= fulfilled
                remaining_produce -= fulfilled
                if o.quantity <= 0:
                    orders_to_remove.append(o)
                    total_reward += o.reward
                    self.cash_balance += o.reward

        for o in orders_to_remove:
            self.pending_orders.remove(o)

        return total_reward

    def _handle_offset(self, action: Action, info: dict) -> float:
        """Buy ``action.offset_amount`` carbon offsets at 2.0 cash per unit.

        Returns 5 on success, -5 for a missing/negative amount or when there is
        no footprint, -10 if cash is insufficient.
        """
        if not action.offset_amount:
            return -5.0
        if action.offset_amount < 0:
            # A negative offset would mint cash.
            info["error"] = "Offset amount must be positive"
            return -5.0
        if self.carbon_total <= 0:
            info["error"] = "No carbon footprint to offset"
            return -5.0

        cost = action.offset_amount * 2.0
        if self.cash_balance < cost:
            info["error"] = "Insufficient funds for offset"
            return -10.0

        self.cash_balance -= cost
        self.carbon_total = max(0.0, self.carbon_total - action.offset_amount)
        return 5.0

    def _handle_reroute(self, action: Action, info: dict) -> float:
        """Placeholder for a reroute action; not dispatched by :meth:`step` and always returns 0."""
        return 0.0

    def _calculate_final_score(self) -> float:
        """Return the episode score in [0.01, 0.99].

        The score is the fraction of originally ordered units delivered. If
        ``carbon_total`` exceeds ``carbon_quota``, it is reduced by half the
        relative overage.
        """
        total_orders = sum(o.quantity for o in self._generate_initial_orders())
        remaining = sum(o.quantity for o in self.pending_orders)

        score = (total_orders - remaining) / total_orders if total_orders > 0 else 1.0

        if self.carbon_total > self.carbon_quota:
            overage = self.carbon_total - self.carbon_quota
            penalty = (overage / self.carbon_quota) * 0.5
            score = max(0.0, score - penalty)

        return float(max(0.01, min(0.99, score)))