ifieryarrows commited on
Commit
af722fe
·
verified ·
1 Parent(s): c081d45

Sync from GitHub (tests passed)

Browse files
Files changed (1) hide show
  1. deep_learning/models/tft_copper.py +30 -15
deep_learning/models/tft_copper.py CHANGED
@@ -251,43 +251,58 @@ def format_prediction(
251
  n_days = pred.shape[0]
252
  median_idx = len(quantiles) // 2
253
 
254
- # Per-day forecast array
 
 
255
  daily_forecasts = []
 
 
 
 
 
 
256
  for d in range(n_days):
257
  med = float(pred[d, median_idx])
258
  q10 = float(pred[d, 1]) if len(quantiles) > 2 else med
259
  q90 = float(pred[d, -2]) if len(quantiles) > 2 else med
260
  q02 = float(pred[d, 0])
261
  q98 = float(pred[d, -1])
 
 
 
 
 
 
 
 
 
262
  daily_forecasts.append({
263
  "day": d + 1,
264
- "return_median": med,
265
- "return_q10": q10,
266
- "return_q90": q90,
267
- "price_median": baseline_price * (1 + med),
268
- "price_q10": baseline_price * (1 + q10),
269
- "price_q90": baseline_price * (1 + q90),
270
- "price_q02": baseline_price * (1 + q02),
271
- "price_q98": baseline_price * (1 + q98),
272
  })
273
 
274
  # T+1 is the primary signal (most reliable, highest signal-to-noise).
275
- # T+5 (end-of-horizon) provides the weekly trend direction.
276
  first = daily_forecasts[0]
277
  last = daily_forecasts[-1]
278
- vol_estimate = (first["return_q90"] - first["return_q10"]) / 2.0
279
 
280
  return {
281
- "predicted_return_median": first["return_median"],
282
- "predicted_return_q10": first["return_q10"],
283
- "predicted_return_q90": first["return_q90"],
284
  "predicted_price_median": first["price_median"],
285
  "predicted_price_q10": first["price_q10"],
286
  "predicted_price_q90": first["price_q90"],
287
  "confidence_band_96": (first["price_q02"], first["price_q98"]),
288
  "volatility_estimate": vol_estimate,
289
  "quantiles": {f"q{q:.2f}": float(pred[0, i]) for i, q in enumerate(quantiles)},
290
- "weekly_return": last["return_median"],
291
  "weekly_price": last["price_median"],
292
  "prediction_horizon_days": n_days,
293
  "daily_forecasts": daily_forecasts,
 
251
  n_days = pred.shape[0]
252
  median_idx = len(quantiles) // 2
253
 
254
+ # Each output step is the predicted *daily* return for that day.
255
+ # Prices are compounded: price_T+d = price_T+d-1 * (1 + return_d).
256
+ # Cumulative return from today: product of (1+r_i) for i=1..d, minus 1.
257
  daily_forecasts = []
258
+ cum_price_med = baseline_price
259
+ cum_price_q10 = baseline_price
260
+ cum_price_q90 = baseline_price
261
+ cum_price_q02 = baseline_price
262
+ cum_price_q98 = baseline_price
263
+
264
  for d in range(n_days):
265
  med = float(pred[d, median_idx])
266
  q10 = float(pred[d, 1]) if len(quantiles) > 2 else med
267
  q90 = float(pred[d, -2]) if len(quantiles) > 2 else med
268
  q02 = float(pred[d, 0])
269
  q98 = float(pred[d, -1])
270
+
271
+ cum_price_med *= (1 + med)
272
+ cum_price_q10 *= (1 + q10)
273
+ cum_price_q90 *= (1 + q90)
274
+ cum_price_q02 *= (1 + q02)
275
+ cum_price_q98 *= (1 + q98)
276
+
277
+ cum_return = (cum_price_med / baseline_price) - 1.0
278
+
279
  daily_forecasts.append({
280
  "day": d + 1,
281
+ "daily_return": med,
282
+ "cumulative_return": cum_return,
283
+ "price_median": cum_price_med,
284
+ "price_q10": cum_price_q10,
285
+ "price_q90": cum_price_q90,
286
+ "price_q02": cum_price_q02,
287
+ "price_q98": cum_price_q98,
 
288
  })
289
 
290
  # T+1 is the primary signal (most reliable, highest signal-to-noise).
 
291
  first = daily_forecasts[0]
292
  last = daily_forecasts[-1]
293
+ vol_estimate = (first["price_q90"] - first["price_q10"]) / (2.0 * baseline_price)
294
 
295
  return {
296
+ "predicted_return_median": first["daily_return"],
297
+ "predicted_return_q10": float(pred[0, 1]) if len(quantiles) > 2 else first["daily_return"],
298
+ "predicted_return_q90": float(pred[0, -2]) if len(quantiles) > 2 else first["daily_return"],
299
  "predicted_price_median": first["price_median"],
300
  "predicted_price_q10": first["price_q10"],
301
  "predicted_price_q90": first["price_q90"],
302
  "confidence_band_96": (first["price_q02"], first["price_q98"]),
303
  "volatility_estimate": vol_estimate,
304
  "quantiles": {f"q{q:.2f}": float(pred[0, i]) for i, q in enumerate(quantiles)},
305
+ "weekly_return": last["cumulative_return"],
306
  "weekly_price": last["price_median"],
307
  "prediction_horizon_days": n_days,
308
  "daily_forecasts": daily_forecasts,