Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files
deep_learning/models/tft_copper.py
CHANGED
|
@@ -251,43 +251,58 @@ def format_prediction(
|
|
| 251 |
n_days = pred.shape[0]
|
| 252 |
median_idx = len(quantiles) // 2
|
| 253 |
|
| 254 |
-
#
|
|
|
|
|
|
|
| 255 |
daily_forecasts = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
for d in range(n_days):
|
| 257 |
med = float(pred[d, median_idx])
|
| 258 |
q10 = float(pred[d, 1]) if len(quantiles) > 2 else med
|
| 259 |
q90 = float(pred[d, -2]) if len(quantiles) > 2 else med
|
| 260 |
q02 = float(pred[d, 0])
|
| 261 |
q98 = float(pred[d, -1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
daily_forecasts.append({
|
| 263 |
"day": d + 1,
|
| 264 |
-
"
|
| 265 |
-
"
|
| 266 |
-
"
|
| 267 |
-
"
|
| 268 |
-
"
|
| 269 |
-
"
|
| 270 |
-
"
|
| 271 |
-
"price_q98": baseline_price * (1 + q98),
|
| 272 |
})
|
| 273 |
|
| 274 |
# T+1 is the primary signal (most reliable, highest signal-to-noise).
|
| 275 |
-
# T+5 (end-of-horizon) provides the weekly trend direction.
|
| 276 |
first = daily_forecasts[0]
|
| 277 |
last = daily_forecasts[-1]
|
| 278 |
-
vol_estimate = (first["
|
| 279 |
|
| 280 |
return {
|
| 281 |
-
"predicted_return_median": first["
|
| 282 |
-
"predicted_return_q10": first["
|
| 283 |
-
"predicted_return_q90": first["
|
| 284 |
"predicted_price_median": first["price_median"],
|
| 285 |
"predicted_price_q10": first["price_q10"],
|
| 286 |
"predicted_price_q90": first["price_q90"],
|
| 287 |
"confidence_band_96": (first["price_q02"], first["price_q98"]),
|
| 288 |
"volatility_estimate": vol_estimate,
|
| 289 |
"quantiles": {f"q{q:.2f}": float(pred[0, i]) for i, q in enumerate(quantiles)},
|
| 290 |
-
"weekly_return": last["
|
| 291 |
"weekly_price": last["price_median"],
|
| 292 |
"prediction_horizon_days": n_days,
|
| 293 |
"daily_forecasts": daily_forecasts,
|
|
|
|
| 251 |
n_days = pred.shape[0]
|
| 252 |
median_idx = len(quantiles) // 2
|
| 253 |
|
| 254 |
+
# Each output step is the predicted *daily* return for that day.
|
| 255 |
+
# Prices are compounded: price_T+d = price_T+d-1 * (1 + return_d).
|
| 256 |
+
# Cumulative return from today: product of (1+r_i) for i=1..d, minus 1.
|
| 257 |
daily_forecasts = []
|
| 258 |
+
cum_price_med = baseline_price
|
| 259 |
+
cum_price_q10 = baseline_price
|
| 260 |
+
cum_price_q90 = baseline_price
|
| 261 |
+
cum_price_q02 = baseline_price
|
| 262 |
+
cum_price_q98 = baseline_price
|
| 263 |
+
|
| 264 |
for d in range(n_days):
|
| 265 |
med = float(pred[d, median_idx])
|
| 266 |
q10 = float(pred[d, 1]) if len(quantiles) > 2 else med
|
| 267 |
q90 = float(pred[d, -2]) if len(quantiles) > 2 else med
|
| 268 |
q02 = float(pred[d, 0])
|
| 269 |
q98 = float(pred[d, -1])
|
| 270 |
+
|
| 271 |
+
cum_price_med *= (1 + med)
|
| 272 |
+
cum_price_q10 *= (1 + q10)
|
| 273 |
+
cum_price_q90 *= (1 + q90)
|
| 274 |
+
cum_price_q02 *= (1 + q02)
|
| 275 |
+
cum_price_q98 *= (1 + q98)
|
| 276 |
+
|
| 277 |
+
cum_return = (cum_price_med / baseline_price) - 1.0
|
| 278 |
+
|
| 279 |
daily_forecasts.append({
|
| 280 |
"day": d + 1,
|
| 281 |
+
"daily_return": med,
|
| 282 |
+
"cumulative_return": cum_return,
|
| 283 |
+
"price_median": cum_price_med,
|
| 284 |
+
"price_q10": cum_price_q10,
|
| 285 |
+
"price_q90": cum_price_q90,
|
| 286 |
+
"price_q02": cum_price_q02,
|
| 287 |
+
"price_q98": cum_price_q98,
|
|
|
|
| 288 |
})
|
| 289 |
|
| 290 |
# T+1 is the primary signal (most reliable, highest signal-to-noise).
|
|
|
|
| 291 |
first = daily_forecasts[0]
|
| 292 |
last = daily_forecasts[-1]
|
| 293 |
+
vol_estimate = (first["price_q90"] - first["price_q10"]) / (2.0 * baseline_price)
|
| 294 |
|
| 295 |
return {
|
| 296 |
+
"predicted_return_median": first["daily_return"],
|
| 297 |
+
"predicted_return_q10": float(pred[0, 1]) if len(quantiles) > 2 else first["daily_return"],
|
| 298 |
+
"predicted_return_q90": float(pred[0, -2]) if len(quantiles) > 2 else first["daily_return"],
|
| 299 |
"predicted_price_median": first["price_median"],
|
| 300 |
"predicted_price_q10": first["price_q10"],
|
| 301 |
"predicted_price_q90": first["price_q90"],
|
| 302 |
"confidence_band_96": (first["price_q02"], first["price_q98"]),
|
| 303 |
"volatility_estimate": vol_estimate,
|
| 304 |
"quantiles": {f"q{q:.2f}": float(pred[0, i]) for i, q in enumerate(quantiles)},
|
| 305 |
+
"weekly_return": last["cumulative_return"],
|
| 306 |
"weekly_price": last["price_median"],
|
| 307 |
"prediction_horizon_days": n_days,
|
| 308 |
"daily_forecasts": daily_forecasts,
|