Jenak5 commited on
Commit
7fd312d
·
verified ·
1 Parent(s): 186d38d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -1
app.py CHANGED
@@ -79,4 +79,163 @@ async def load_model():
79
  except Exception as e:
80
  logger.warning(f"Predictor weights not loaded: {e}")
81
 
82
- predictor = Kr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  except Exception as e:
80
  logger.warning(f"Predictor weights not loaded: {e}")
81
 
82
+ predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
83
+ logger.info("Kronos NQ+ES v2 ready.")
84
+
85
+
86
+ class CandleOut(BaseModel):
87
+ timestamp: str
88
+ open: float
89
+ high: float
90
+ low: float
91
+ close: float
92
+
93
+
94
+ class ForecastResponse(BaseModel):
95
+ instrument: str
96
+ timeframe: str
97
+ generated_at: str
98
+ historical: list[CandleOut]
99
+ forecast_mean: list[CandleOut]
100
+ forecast_upper: list[CandleOut]
101
+ forecast_lower: list[CandleOut]
102
+ direction: str
103
+ confidence: float
104
+ volatility_ratio: float
105
+
106
+
107
+ def fetch_candles(ticker: str, interval: str, period: str) -> pd.DataFrame:
108
+ raw = yf.download(ticker, period=period, interval=interval, progress=False)
109
+ if raw.empty:
110
+ raise HTTPException(status_code=502, detail=f"No data for {ticker}")
111
+ df = raw.reset_index()
112
+ if hasattr(df.columns, 'levels'):
113
+ df.columns = [c[0] if isinstance(c, tuple) else c for c in df.columns]
114
+ rename = {c: c.lower() for c in df.columns}
115
+ df.rename(columns=rename, inplace=True)
116
+ if "datetime" in df.columns:
117
+ df.rename(columns={"datetime": "timestamp"}, inplace=True)
118
+ elif "date" in df.columns:
119
+ df.rename(columns={"date": "timestamp"}, inplace=True)
120
+ df["timestamp"] = pd.to_datetime(df["timestamp"])
121
+ return df[["timestamp", "open", "high", "low", "close"]].dropna()
122
+
123
+
124
+ def run_forecast(df: pd.DataFrame, forecast_bars: int, n_samples: int = 10):
125
+ lookback = min(len(df), 400)
126
+ x_df = df.tail(lookback).reset_index(drop=True)
127
+
128
+ freq = pd.infer_freq(x_df["timestamp"])
129
+ if freq is None:
130
+ delta = x_df["timestamp"].iloc[-1] - x_df["timestamp"].iloc[-2]
131
+ future_ts = pd.Series([x_df["timestamp"].iloc[-1] + delta * (i + 1) for i in range(forecast_bars)])
132
+ else:
133
+ future_ts = pd.Series(pd.date_range(
134
+ start=x_df["timestamp"].iloc[-1],
135
+ periods=forecast_bars + 1,
136
+ freq=freq,
137
+ )[1:])
138
+
139
+ samples = []
140
+ for _ in range(n_samples):
141
+ pred_df = predictor.predict(
142
+ df=x_df[["open", "high", "low", "close"]],
143
+ x_timestamp=x_df["timestamp"],
144
+ y_timestamp=future_ts,
145
+ pred_len=forecast_bars,
146
+ T=0.3,
147
+ top_p=0.5,
148
+ sample_count=1,
149
+ )
150
+ samples.append(pred_df[["open", "high", "low", "close"]].values)
151
+
152
+ samples = np.array(samples)
153
+ mean = samples.mean(axis=0)
154
+ upper = np.percentile(samples, 90, axis=0)
155
+ lower = np.percentile(samples, 10, axis=0)
156
+
157
+ return mean, upper, lower, future_ts
158
+
159
+
160
+ def calc_direction(mean_candles: np.ndarray, last_close: float):
161
+ final_close = mean_candles[-1, 3]
162
+ pct_change = (final_close - last_close) / last_close * 100
163
+
164
+ if pct_change > 0.10:
165
+ return "BULLISH", min(abs(pct_change) * 30, 95)
166
+ elif pct_change < -0.10:
167
+ return "BEARISH", min(abs(pct_change) * 30, 95)
168
+ else:
169
+ return "NEUTRAL", max(50 - abs(pct_change) * 150, 10)
170
+
171
+
172
+ def calc_vol_ratio(mean_candles: np.ndarray, hist_df: pd.DataFrame):
173
+ pred_ranges = mean_candles[:, 1] - mean_candles[:, 2]
174
+ hist_ranges = (hist_df["high"] - hist_df["low"]).tail(len(mean_candles)).values
175
+ if hist_ranges.mean() == 0:
176
+ return 1.0
177
+ return float(pred_ranges.mean() / hist_ranges.mean())
178
+
179
+
180
+ def candles_to_list(arr: np.ndarray, timestamps) -> list[CandleOut]:
181
+ out = []
182
+ for i, ts in enumerate(timestamps):
183
+ out.append(CandleOut(
184
+ timestamp=str(ts),
185
+ open=round(float(arr[i, 0]), 2),
186
+ high=round(float(arr[i, 1]), 2),
187
+ low=round(float(arr[i, 2]), 2),
188
+ close=round(float(arr[i, 3]), 2),
189
+ ))
190
+ return out
191
+
192
+
193
+ @app.get("/forecast", response_model=ForecastResponse)
194
+ async def get_forecast(
195
+ instrument: str = Query("NQ", pattern="^(NQ|ES)$"),
196
+ timeframe: str = Query("1h", pattern="^(5m|1h)$"),
197
+ ):
198
+ if predictor is None:
199
+ raise HTTPException(status_code=503, detail="Model still loading")
200
+
201
+ ticker = TICKER_MAP[instrument]
202
+ tf_cfg = TIMEFRAME_MAP[timeframe]
203
+
204
+ df = fetch_candles(ticker, tf_cfg["interval"], tf_cfg["period"])
205
+ logger.info(f"Fetched {len(df)} candles for {instrument} @ {timeframe}")
206
+
207
+ mean, upper, lower, future_ts = run_forecast(df, tf_cfg["forecast_bars"])
208
+
209
+ last_close = float(df["close"].iloc[-1])
210
+ direction, confidence = calc_direction(mean, last_close)
211
+ vol_ratio = calc_vol_ratio(mean, df)
212
+
213
+ hist_tail = df.tail(50)
214
+ historical = [
215
+ CandleOut(
216
+ timestamp=str(row.timestamp),
217
+ open=round(float(row.open), 2),
218
+ high=round(float(row.high), 2),
219
+ low=round(float(row.low), 2),
220
+ close=round(float(row.close), 2),
221
+ )
222
+ for row in hist_tail.itertuples()
223
+ ]
224
+
225
+ return ForecastResponse(
226
+ instrument=instrument,
227
+ timeframe=timeframe,
228
+ generated_at=datetime.utcnow().isoformat() + "Z",
229
+ historical=historical,
230
+ forecast_mean=candles_to_list(mean, future_ts),
231
+ forecast_upper=candles_to_list(upper, future_ts),
232
+ forecast_lower=candles_to_list(lower, future_ts),
233
+ direction=direction,
234
+ confidence=round(confidence, 1),
235
+ volatility_ratio=round(vol_ratio, 2),
236
+ )
237
+
238
+
239
+ @app.get("/health")
240
+ async def health():
241
+ return {"status": "ok", "model_loaded": predictor is not None}