Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -79,4 +79,163 @@ async def load_model():
|
|
| 79 |
except Exception as e:
|
| 80 |
logger.warning(f"Predictor weights not loaded: {e}")
|
| 81 |
|
| 82 |
-
predictor =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
except Exception as e:
|
| 80 |
logger.warning(f"Predictor weights not loaded: {e}")
|
| 81 |
|
| 82 |
+
predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
|
| 83 |
+
logger.info("Kronos NQ+ES v2 ready.")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class CandleOut(BaseModel):
|
| 87 |
+
timestamp: str
|
| 88 |
+
open: float
|
| 89 |
+
high: float
|
| 90 |
+
low: float
|
| 91 |
+
close: float
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ForecastResponse(BaseModel):
|
| 95 |
+
instrument: str
|
| 96 |
+
timeframe: str
|
| 97 |
+
generated_at: str
|
| 98 |
+
historical: list[CandleOut]
|
| 99 |
+
forecast_mean: list[CandleOut]
|
| 100 |
+
forecast_upper: list[CandleOut]
|
| 101 |
+
forecast_lower: list[CandleOut]
|
| 102 |
+
direction: str
|
| 103 |
+
confidence: float
|
| 104 |
+
volatility_ratio: float
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def fetch_candles(ticker: str, interval: str, period: str) -> pd.DataFrame:
|
| 108 |
+
raw = yf.download(ticker, period=period, interval=interval, progress=False)
|
| 109 |
+
if raw.empty:
|
| 110 |
+
raise HTTPException(status_code=502, detail=f"No data for {ticker}")
|
| 111 |
+
df = raw.reset_index()
|
| 112 |
+
if hasattr(df.columns, 'levels'):
|
| 113 |
+
df.columns = [c[0] if isinstance(c, tuple) else c for c in df.columns]
|
| 114 |
+
rename = {c: c.lower() for c in df.columns}
|
| 115 |
+
df.rename(columns=rename, inplace=True)
|
| 116 |
+
if "datetime" in df.columns:
|
| 117 |
+
df.rename(columns={"datetime": "timestamp"}, inplace=True)
|
| 118 |
+
elif "date" in df.columns:
|
| 119 |
+
df.rename(columns={"date": "timestamp"}, inplace=True)
|
| 120 |
+
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
| 121 |
+
return df[["timestamp", "open", "high", "low", "close"]].dropna()
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def run_forecast(df: pd.DataFrame, forecast_bars: int, n_samples: int = 10):
|
| 125 |
+
lookback = min(len(df), 400)
|
| 126 |
+
x_df = df.tail(lookback).reset_index(drop=True)
|
| 127 |
+
|
| 128 |
+
freq = pd.infer_freq(x_df["timestamp"])
|
| 129 |
+
if freq is None:
|
| 130 |
+
delta = x_df["timestamp"].iloc[-1] - x_df["timestamp"].iloc[-2]
|
| 131 |
+
future_ts = pd.Series([x_df["timestamp"].iloc[-1] + delta * (i + 1) for i in range(forecast_bars)])
|
| 132 |
+
else:
|
| 133 |
+
future_ts = pd.Series(pd.date_range(
|
| 134 |
+
start=x_df["timestamp"].iloc[-1],
|
| 135 |
+
periods=forecast_bars + 1,
|
| 136 |
+
freq=freq,
|
| 137 |
+
)[1:])
|
| 138 |
+
|
| 139 |
+
samples = []
|
| 140 |
+
for _ in range(n_samples):
|
| 141 |
+
pred_df = predictor.predict(
|
| 142 |
+
df=x_df[["open", "high", "low", "close"]],
|
| 143 |
+
x_timestamp=x_df["timestamp"],
|
| 144 |
+
y_timestamp=future_ts,
|
| 145 |
+
pred_len=forecast_bars,
|
| 146 |
+
T=0.3,
|
| 147 |
+
top_p=0.5,
|
| 148 |
+
sample_count=1,
|
| 149 |
+
)
|
| 150 |
+
samples.append(pred_df[["open", "high", "low", "close"]].values)
|
| 151 |
+
|
| 152 |
+
samples = np.array(samples)
|
| 153 |
+
mean = samples.mean(axis=0)
|
| 154 |
+
upper = np.percentile(samples, 90, axis=0)
|
| 155 |
+
lower = np.percentile(samples, 10, axis=0)
|
| 156 |
+
|
| 157 |
+
return mean, upper, lower, future_ts
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def calc_direction(mean_candles: np.ndarray, last_close: float):
|
| 161 |
+
final_close = mean_candles[-1, 3]
|
| 162 |
+
pct_change = (final_close - last_close) / last_close * 100
|
| 163 |
+
|
| 164 |
+
if pct_change > 0.10:
|
| 165 |
+
return "BULLISH", min(abs(pct_change) * 30, 95)
|
| 166 |
+
elif pct_change < -0.10:
|
| 167 |
+
return "BEARISH", min(abs(pct_change) * 30, 95)
|
| 168 |
+
else:
|
| 169 |
+
return "NEUTRAL", max(50 - abs(pct_change) * 150, 10)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def calc_vol_ratio(mean_candles: np.ndarray, hist_df: pd.DataFrame):
|
| 173 |
+
pred_ranges = mean_candles[:, 1] - mean_candles[:, 2]
|
| 174 |
+
hist_ranges = (hist_df["high"] - hist_df["low"]).tail(len(mean_candles)).values
|
| 175 |
+
if hist_ranges.mean() == 0:
|
| 176 |
+
return 1.0
|
| 177 |
+
return float(pred_ranges.mean() / hist_ranges.mean())
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def candles_to_list(arr: np.ndarray, timestamps) -> list[CandleOut]:
|
| 181 |
+
out = []
|
| 182 |
+
for i, ts in enumerate(timestamps):
|
| 183 |
+
out.append(CandleOut(
|
| 184 |
+
timestamp=str(ts),
|
| 185 |
+
open=round(float(arr[i, 0]), 2),
|
| 186 |
+
high=round(float(arr[i, 1]), 2),
|
| 187 |
+
low=round(float(arr[i, 2]), 2),
|
| 188 |
+
close=round(float(arr[i, 3]), 2),
|
| 189 |
+
))
|
| 190 |
+
return out
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@app.get("/forecast", response_model=ForecastResponse)
|
| 194 |
+
async def get_forecast(
|
| 195 |
+
instrument: str = Query("NQ", pattern="^(NQ|ES)$"),
|
| 196 |
+
timeframe: str = Query("1h", pattern="^(5m|1h)$"),
|
| 197 |
+
):
|
| 198 |
+
if predictor is None:
|
| 199 |
+
raise HTTPException(status_code=503, detail="Model still loading")
|
| 200 |
+
|
| 201 |
+
ticker = TICKER_MAP[instrument]
|
| 202 |
+
tf_cfg = TIMEFRAME_MAP[timeframe]
|
| 203 |
+
|
| 204 |
+
df = fetch_candles(ticker, tf_cfg["interval"], tf_cfg["period"])
|
| 205 |
+
logger.info(f"Fetched {len(df)} candles for {instrument} @ {timeframe}")
|
| 206 |
+
|
| 207 |
+
mean, upper, lower, future_ts = run_forecast(df, tf_cfg["forecast_bars"])
|
| 208 |
+
|
| 209 |
+
last_close = float(df["close"].iloc[-1])
|
| 210 |
+
direction, confidence = calc_direction(mean, last_close)
|
| 211 |
+
vol_ratio = calc_vol_ratio(mean, df)
|
| 212 |
+
|
| 213 |
+
hist_tail = df.tail(50)
|
| 214 |
+
historical = [
|
| 215 |
+
CandleOut(
|
| 216 |
+
timestamp=str(row.timestamp),
|
| 217 |
+
open=round(float(row.open), 2),
|
| 218 |
+
high=round(float(row.high), 2),
|
| 219 |
+
low=round(float(row.low), 2),
|
| 220 |
+
close=round(float(row.close), 2),
|
| 221 |
+
)
|
| 222 |
+
for row in hist_tail.itertuples()
|
| 223 |
+
]
|
| 224 |
+
|
| 225 |
+
return ForecastResponse(
|
| 226 |
+
instrument=instrument,
|
| 227 |
+
timeframe=timeframe,
|
| 228 |
+
generated_at=datetime.utcnow().isoformat() + "Z",
|
| 229 |
+
historical=historical,
|
| 230 |
+
forecast_mean=candles_to_list(mean, future_ts),
|
| 231 |
+
forecast_upper=candles_to_list(upper, future_ts),
|
| 232 |
+
forecast_lower=candles_to_list(lower, future_ts),
|
| 233 |
+
direction=direction,
|
| 234 |
+
confidence=round(confidence, 1),
|
| 235 |
+
volatility_ratio=round(vol_ratio, 2),
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@app.get("/health")
|
| 240 |
+
async def health():
|
| 241 |
+
return {"status": "ok", "model_loaded": predictor is not None}
|