File size: 9,951 Bytes
9691f5e
 
 
 
 
 
 
c3fc8d4
9691f5e
b413222
 
9691f5e
 
 
 
 
 
7ed1454
9691f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c10dcd0
9691f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c10dcd0
 
 
 
 
9691f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
c10dcd0
9691f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c041c09
 
 
 
 
 
 
9691f5e
c041c09
9691f5e
 
 
 
c041c09
9691f5e
c041c09
9691f5e
c041c09
7ed1454
 
c10dcd0
 
 
 
9691f5e
 
 
 
 
 
c041c09
 
9691f5e
 
 
 
 
 
 
 
 
7ed1454
 
 
 
 
 
 
 
 
 
9691f5e
 
 
 
 
 
 
 
 
c041c09
 
7ed1454
 
9691f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b413222
 
c3fc8d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a9473a
c3fc8d4
 
 
 
 
 
 
 
b413222
 
 
 
 
 
 
 
 
 
2cee429
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from dataclasses import dataclass, asdict
from typing import List, Optional
import numpy as np
import httpx
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel

from config import (
    SIM_DAYS, HISTO_DAYS, LEAD_TIME,
    WRITE_OFF_RATE, WRITE_OFF_FREQUENCY,
)
from reward import compute_daily_pnl
from demand_environment import (
    GammaPoisson, GammaGammaHighVariance, SpikingDemand, SingleGammaLowVariance,
)
from demand_calculator import DemandCalculator
from order_processor import OrderProcessor
from performance_tracker import PerformanceTracker

app = FastAPI(title="Inventory Reasoning Environment")

# Demand-environment classes selectable through `/reset?env_type=N`.
# The integer key is the class's position in this tuple:
# 0 = GammaPoisson, 1 = GammaGammaHighVariance, 2 = SpikingDemand,
# 3 = SingleGammaLowVariance.
ENV_TYPES = dict(
    enumerate(
        (
            GammaPoisson,
            GammaGammaHighVariance,
            SpikingDemand,
            SingleGammaLowVariance,
        )
    )
)


# ── Pydantic models (request/response) ───────────────────────────────────────

class InventoryAction(BaseModel):
    """Agent action submitted to ``/step``.

    ``reorder_point`` is compared by ``step`` against the inventory position
    (on-hand plus pipeline); negative values are clamped to 0 there.
    ``reasoning`` is free text; ``step`` echoes at most its first 200
    characters back in ``info["reasoning_logged"]``.
    """
    reorder_point: float
    reasoning: str = ""


class PendingOrder(BaseModel):
    """An outstanding replenishment order as exposed in an observation."""
    arrival_day: int  # simulation day on which ``step`` adds the quantity to inventory
    quantity: int  # units on order


class InventoryObservation(BaseModel):
    """Observation returned by ``/reset`` and inside every ``/step`` result.

    The recent-demand fields (``demand_last_5``, ``demand_*_30d``) are built
    only from days strictly before ``day``.
    """
    day: int  # index of the next day to be simulated
    current_inventory: float  # on-hand units (pending orders not included)
    demand_last_5: List[float]  # demand of up to the 5 most recent days
    demand_mean_30d: float  # mean demand over up to the last 30 days (0.0 if none)
    demand_std_30d: float  # population std over the same window (0.0 if < 2 days)
    fill_rate_so_far: float  # cumulative fulfilled / demand since reset (0.0 before any demand)
    recent_stockouts: int  # NOTE: despite the name, filled with the cumulative stockout-day count since reset
    recent_lost_sales: float  # NOTE: despite the name, filled with cumulative lost units since reset
    days_remaining: int  # SIM_DAYS - day
    pending_orders: List[PendingOrder]  # at most the first 5 entries of the order queue
    demand_last_year_7d: List[float]  # intended: demand in the 7-day window around day - 365 (see EpisodeState.get_obs)


class StepResult(BaseModel):
    """Response of ``/step``: next observation, the day's reward, and a done flag."""
    observation: InventoryObservation  # state at the start of the following day
    reward: float  # ``daily_reward`` returned by ``compute_daily_pnl``
    done: bool  # True once the episode clock reaches SIM_DAYS
    info: dict  # diagnostics: fill rate, stockouts, sales, profit, truncated reasoning


class StateResponse(BaseModel):
    """Cumulative episode statistics returned by ``/state``."""
    day: int  # index of the next day to be simulated
    fill_rate: float  # total_fulfilled / total_demand (0.0 before any demand)
    done: bool  # True once day >= SIM_DAYS
    total_demand: float  # units demanded since reset
    total_fulfilled: float  # units sold since reset
    stockouts: int  # number of days with unmet demand since reset
    lost_sales: float  # unmet demand units since reset


# ── Episode state (single global episode for simplicity) ─────────────────────

class EpisodeState:
    """Mutable state of the single simulation episode served by this app.

    ``reset_state`` clears every field; the ``/reset`` endpoint then fills
    ``demand_series`` and moves ``day`` past the warm-up history.
    """

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        """Return every field to its pre-``/reset`` value."""
        self.day: int = 0
        self.inventory: float = 0.0
        # Demand for every simulated day of the episode, generated up front.
        self.demand_series: List[int] = []
        self.order_processor = OrderProcessor()
        self.performance_tracker = PerformanceTracker()
        # Cumulative counters since the last reset.
        self.total_demand: float = 0.0
        self.total_fulfilled: float = 0.0
        self.stockouts: int = 0
        self.lost_sales: float = 0.0
        self.initialized: bool = False

    def get_obs(self) -> InventoryObservation:
        """Build the observation the agent sees at the start of ``self.day``.

        Every demand value exposed comes from a day strictly before
        ``self.day``. The agent never sees the demand it is about to face or
        any later day.
        """
        series = self.demand_series
        day = self.day
        last5 = series[max(0, day - 5):day]
        hist30 = series[max(0, day - 30):day]

        pending = [
            PendingOrder(arrival_day=o.arrival_day, quantity=o.quantity)
            for o in self.order_processor.order_queue[:5]
        ]

        # 7-day window centred on the same day one year earlier.
        # The end index is clamped to be >= the start index. Without that, a
        # negative end (day < 361) is read by Python slicing as "count from
        # the end", and most of the series -- including future days -- would
        # leak into the observation.
        ly_anchor = day - 365
        ly_start = max(0, ly_anchor - 3)
        ly_end = max(ly_start, min(len(series), ly_anchor + 4))
        demand_last_year_7d = [float(d) for d in series[ly_start:ly_end]]

        return InventoryObservation(
            day=day,
            current_inventory=self.inventory,
            demand_last_5=[float(d) for d in last5],
            demand_mean_30d=float(np.mean(hist30)) if hist30 else 0.0,
            demand_std_30d=float(np.std(hist30)) if len(hist30) > 1 else 0.0,
            fill_rate_so_far=(
                self.total_fulfilled / self.total_demand
                if self.total_demand > 0 else 0.0
            ),
            # These two fields receive cumulative totals since reset.
            recent_stockouts=self.stockouts,
            recent_lost_sales=self.lost_sales,
            days_remaining=SIM_DAYS - day,
            pending_orders=pending,
            demand_last_year_7d=demand_last_year_7d,
        )


# Module-level singleton: every endpoint below reads and mutates this one
# episode, so all clients share the same simulation state.
episode = EpisodeState()


# ── Endpoints ─────────────────────────────────────────────────────────────────

@app.post("/reset", response_model=InventoryObservation)
def reset(env_type: int = 0):
    """Start a new episode and return its first observation.

    ``env_type`` picks a demand environment from ``ENV_TYPES``. The full
    demand series is generated up front, and the clock starts at
    ``HISTO_DAYS`` so the agent has that much history available.

    Raises:
        HTTPException(400): if ``env_type`` is not a key of ``ENV_TYPES``.
    """
    env_class = ENV_TYPES.get(env_type)
    if env_class is None:
        raise HTTPException(status_code=400, detail=f"env_type must be 0-{len(ENV_TYPES)-1}")

    episode.reset_state()

    # Environment first, then calculator -- same construction order as before.
    demand_env = env_class(SIM_DAYS)
    calculator = DemandCalculator(SIM_DAYS)
    calculator.set_environment(demand_env)
    episode.demand_series = list(map(calculator.get_daily_demand, range(SIM_DAYS)))

    # The first HISTO_DAYS days are warm-up history only; acting begins after them.
    episode.day = HISTO_DAYS
    episode.initialized = True
    return episode.get_obs()


@app.post("/step", response_model=StepResult)
def step(action: InventoryAction):
    """Simulate one day using the agent's reorder point and return the outcome.

    Order of events within the day:
      1. receive every pending order that is due (arrival day <= today),
      2. write off a ``WRITE_OFF_RATE`` fraction of on-hand stock,
      3. serve the day's demand (unmet demand is lost, not backordered),
      4. if the inventory position (on-hand + pipeline) is at or below the
         reorder point, order whole units up to roughly
         ``reorder_point + mean_demand * LEAD_TIME``,
      5. record performance, advance the clock and compute the reward.

    Raises:
        HTTPException(400): if ``/reset`` has not been called, or the episode
            is already finished.
    """
    if not episode.initialized:
        raise HTTPException(status_code=400, detail="Call /reset before /step")
    if episode.day >= SIM_DAYS:
        raise HTTPException(status_code=400, detail="Episode already done. Call /reset.")

    day = episode.day
    demand = episode.demand_series[day]

    # 1. Deliver pending orders. Receive everything due on or before today:
    # every such order is removed from the queue below, so matching only
    # `== day` would silently drop an order whose arrival day had already
    # passed.
    queue = episode.order_processor.order_queue
    delivered = sum(o.quantity for o in queue if o.arrival_day <= day)
    episode.inventory += delivered
    episode.order_processor.order_queue = [o for o in queue if o.arrival_day > day]

    # 2. Daily spoilage: a WRITE_OFF_RATE fraction of on-hand stock is written off.
    spoilage = episode.inventory * WRITE_OFF_RATE
    episode.inventory = max(0.0, episode.inventory - spoilage)
    episode.performance_tracker.write_offs += spoilage

    # 3. Fulfill demand; any shortfall is a lost sale.
    units_sold = min(demand, episode.inventory)
    episode.inventory = max(0.0, episode.inventory - demand)
    lost = max(0.0, demand - units_sold)
    if lost > 0:
        episode.stockouts += 1
    episode.lost_sales += lost
    episode.total_demand += demand
    episode.total_fulfilled += units_sold

    # 4. Reorder if the inventory position is at or below the ROP.
    # No order is placed when it could not arrive before the episode ends.
    rop = max(0.0, action.reorder_point)
    qty = 0
    hist = episode.demand_series[max(0, day - 30):day]
    mean_demand = float(np.mean(hist)) if hist else 0.0
    pipeline = sum(o.quantity for o in episode.order_processor.order_queue)
    inv_position = episode.inventory + pipeline
    if day < SIM_DAYS - LEAD_TIME and inv_position <= rop:
        # Orders are placed in whole units. Use the same integer quantity for
        # the order and for the P&L, so the reward reflects what was actually
        # ordered and no zero-unit order is ever placed.
        qty = int(max(0.0, rop - inv_position + mean_demand * LEAD_TIME))
        if qty > 0:
            episode.order_processor.place_order(day, qty)

    # 5. Track performance. Spoilage was already added to `write_offs` in
    # step 2, which is presumably why 0 is passed here (avoids double count).
    episode.performance_tracker.daily_performance(
        demand_quantity=demand,
        fulfilled_demand=int(units_sold),
        daily_writeoff=0,
    )

    episode.day += 1
    done = episode.day >= SIM_DAYS

    fill_rate = (
        episode.total_fulfilled / episode.total_demand
        if episode.total_demand > 0 else 0.0
    )

    pnl = compute_daily_pnl(
        units_sold=units_sold,
        lost=lost,
        inventory_after=episode.inventory,
        ordered_qty=qty,
        spoilage=spoilage,
        mean_demand=mean_demand,
    )
    reward = pnl["daily_reward"]

    return StepResult(
        observation=episode.get_obs(),
        reward=reward,
        done=done,
        info={
            "fill_rate": fill_rate,
            "stockouts": episode.stockouts,
            "lost_sales": episode.lost_sales,
            "inventory_in": delivered,
            "units_sold": units_sold,
            "daily_profit": pnl["daily_profit"],
            "daily_reward": pnl["daily_reward"],
            "reasoning_logged": action.reasoning[:200] if action.reasoning else "",
        },
    )


@app.get("/state", response_model=StateResponse)
def state():
    """Return cumulative statistics for the current episode.

    Raises:
        HTTPException(400): if ``/reset`` has not been called yet.
    """
    if not episode.initialized:
        raise HTTPException(status_code=400, detail="Call /reset first")

    ep = episode
    if ep.total_demand > 0:
        fill_rate = ep.total_fulfilled / ep.total_demand
    else:
        fill_rate = 0.0

    summary = dict(
        day=ep.day,
        fill_rate=fill_rate,
        done=ep.day >= SIM_DAYS,
        total_demand=ep.total_demand,
        total_fulfilled=ep.total_fulfilled,
        stockouts=ep.stockouts,
        lost_sales=ep.lost_sales,
    )
    return StateResponse(**summary)


# ── HF Inference API proxy (avoids browser CSP restrictions on HF Spaces) ────

class QwenRequest(BaseModel):
    """Body of ``/api/qwen``; forwarded as an OpenAI-style chat-completion request."""
    model: str  # model id passed through to the HF router unchanged
    messages: list  # chat messages passed through unchanged (shape not validated here)
    max_tokens: int = 600
    temperature: float = 0.7
    # Caller-supplied HF token. When empty, the proxy falls back to the server's
    # HF_TOKEN environment variable.
    # NOTE(review): this lets any client spend the operator's quota.
    hf_token: str = ""

@app.post("/api/qwen", include_in_schema=False)
async def qwen_proxy(req: QwenRequest):
    """Relay a chat-completion request to the Hugging Face router.

    The request is sent from the server rather than the browser (see the
    section comment about CSP on HF Spaces). The bearer token is the caller's
    ``hf_token`` when supplied, otherwise the server's ``HF_TOKEN`` env var;
    no Authorization header is sent if neither is set.

    Returns:
        The upstream JSON response body, unchanged.

    Raises:
        HTTPException:
            - with the upstream status code and body text when the router
              answers with a non-2xx status;
            - with 502 when the router cannot be reached, times out, or
              returns a body that is not valid JSON.
    """
    # SECURITY NOTE(review): when the client sends no token, the server's own
    # HF_TOKEN is used for a client-chosen `model` and `max_tokens`. This makes
    # the route an open proxy on the operator's quota; consider restricting
    # the allowed models or requiring a client token.
    token = req.hf_token or os.environ.get("HF_TOKEN", "")
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    url = "https://router.huggingface.co/hf-inference/v1/chat/completions"
    payload = {
        "model": req.model,
        "messages": req.messages,
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(url, json=payload, headers=headers)
    except httpx.HTTPError as exc:
        # Connection errors and timeouts would otherwise surface as an opaque 500.
        raise HTTPException(
            status_code=502,
            detail=f"Upstream request failed: {type(exc).__name__}: {exc}",
        ) from exc
    if not resp.is_success:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    try:
        return resp.json()
    except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(status_code=502, detail="Upstream returned a non-JSON response") from exc


# ── Serve React frontend (static files built by Dockerfile) ──────────────────
_static_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
if os.path.isdir(_static_dir):
    # StaticFiles raises RuntimeError at construction when its directory does
    # not exist. A static/ folder without an assets/ subfolder would otherwise
    # crash the whole app at import time, so mount /assets only if present.
    _assets_dir = os.path.join(_static_dir, "assets")
    if os.path.isdir(_assets_dir):
        app.mount("/assets", StaticFiles(directory=_assets_dir), name="assets")

    @app.get("/", include_in_schema=False)
    @app.get("/{full_path:path}", include_in_schema=False)
    async def serve_spa(full_path: str = ""):
        """Return ``index.html`` for any GET path not matched by an earlier route.

        ``full_path`` is ignored; presumably the React app resolves it with
        client-side routing. The response carries ``Cache-Control: no-store``
        so clients do not cache it.

        NOTE(review): this catch-all also answers mistyped API GET paths
        with index.html and status 200 instead of a 404.
        """
        # API routes are registered above, so they take precedence over this catch-all.
        index = os.path.join(_static_dir, "index.html")
        return FileResponse(index, headers={"Cache-Control": "no-store, no-cache, must-revalidate"})