# jtassos2025's picture
# Initial commit on new main
# 97ff2d3
import os
import joblib
import torch
import json
import numpy as np
import pandas as pd
import gradio as gr
from utils.retrain import DualLSTM, add_features
from sklearn.preprocessing import StandardScaler
# ==========================================================
# Device setup
# ==========================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Paths (resolved relative to this file so the app works from any CWD)
REPO_DIR = os.path.dirname(__file__)
MODELS_DIR = os.path.join(REPO_DIR, "models")
DATA_PATH = os.path.join(REPO_DIR, "data/eth_combined_updated.csv")

# Number of stochastic forward passes for Monte Carlo dropout
MC_DROPOUT_PASSES = 100

# Forecast horizon options: UI label -> model/scaler filename prefix
HORIZONS = {
    "Daily": "daily",
    "Weekly": "weekly",
    "4-Weeks": "4-weeks",
    "8-Weeks": "8-weeks",
    "12-Weeks": "12-weeks"
}

# ==========================================================
# Safe torch load fix for PyTorch 2.6+
# ==========================================================
# PyTorch 2.6 made weights_only=True the default for torch.load, which
# rejects arbitrary pickled objects. The checkpoints here may embed a
# StandardScaler (imported at the top of the file), so allow-list it.
# Older torch versions lack add_safe_globals; ignore that case.
try:
    torch.serialization.add_safe_globals([StandardScaler])
except Exception:
    pass
# ==========================================================
# Predict function
# ==========================================================
def predict_eth(horizon_label):
    """Run a Monte Carlo dropout forecast for the selected horizon.

    Args:
        horizon_label: One of the keys of ``HORIZONS`` (e.g. "Daily").

    Returns:
        A JSON string with forecast statistics (predicted close, return
        mean/std, confidence, dates), or a JSON error message when the
        model/scaler files for the horizon are missing.
    """
    horizon = HORIZONS[horizon_label]
    model_path = os.path.join(MODELS_DIR, f"{horizon}_model.pth")
    scaler_path = os.path.join(MODELS_DIR, f"{horizon}_scaler.pkl")

    if not (os.path.exists(model_path) and os.path.exists(scaler_path)):
        return json.dumps({
            "error": f"Missing files for {horizon_label}. Ensure {horizon}_model.pth and {horizon}_scaler.pkl exist."
        }, indent=2)

    # Load checkpoint. PyTorch 2.6+ defaults to weights_only=True, which can
    # reject checkpoints containing pickled objects; retry with the explicit
    # weights_only=False fallback in that case.
    try:
        checkpoint = torch.load(model_path, map_location=DEVICE)
    except Exception:
        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)

    # Rebuild the model from the hyperparameters stored in the checkpoint.
    model = DualLSTM(
        input_size=len(checkpoint["features"]),
        hidden_size=checkpoint["hidden_size"],
        num_layers=checkpoint["num_layers"],
        dropout=checkpoint["dropout"]
    ).to(DEVICE)
    model.load_state_dict(checkpoint["model_state"])
    model.train()  # keep dropout active so repeated passes are stochastic (MC dropout)

    # Load the fitted feature scaler saved alongside the model.
    scaler = joblib.load(scaler_path)

    # Load the price history and normalize column names.
    df = pd.read_csv(DATA_PATH)
    df.columns = df.columns.str.lower().str.replace(" ", "_")
    df = df[["open", "high", "low", "close", "start", "end"]].copy()
    df = add_features(df)

    # Record the last known date; prefer the "start" column, fall back to "end".
    last_recorded_date = None
    if "start" in df.columns:
        last_recorded_date = pd.to_datetime(df["start"].iloc[-1])
    elif "end" in df.columns:
        last_recorded_date = pd.to_datetime(df["end"].iloc[-1])

    # Scale exactly the feature columns the model was trained on, in
    # training order (wrapped in a DataFrame so the scaler sees the names).
    features = checkpoint["features"]
    data_scaled = scaler.transform(pd.DataFrame(df[features].values, columns=features))

    # Build the short and long input windows expected by DualLSTM
    # (batch dimension added via unsqueeze).
    seq_short = torch.tensor(data_scaled[-checkpoint["seq_len_short"]:], dtype=torch.float32).unsqueeze(0).to(DEVICE)
    seq_long = torch.tensor(data_scaled[-checkpoint["seq_len_long"]:], dtype=torch.float32).unsqueeze(0).to(DEVICE)

    # MC dropout: repeated stochastic forward passes give a predictive
    # distribution over the next-period return.
    preds = []
    with torch.no_grad():
        for _ in range(MC_DROPOUT_PASSES):
            preds.append(model(seq_short, seq_long).cpu().numpy().flatten()[0])

    # === Compute stats ===
    preds = np.array(preds)
    pred_mean = float(preds.mean())
    pred_std = float(preds.std())

    last_close = float(df["close"].iloc[-1])
    predicted_close = last_close * (1 + pred_mean)
    # Heuristic confidence score in (0, 1]: decays as spread grows.
    confidence = float(np.exp(-pred_std * 100))

    # Derived info
    predicted_min_close = predicted_close * (1 - pred_std)  # conservative lower bound
    predicted_max_close = predicted_close * (1 + pred_std)  # optimistic upper bound

    # Project the forecast target date by horizon length in days.
    horizon_days = {
        "daily": 1,
        "weekly": 7,
        "4-weeks": 28,
        "8-weeks": 56,
        "12-weeks": 84,
    }
    predicted_date = None
    if last_recorded_date is not None and horizon in horizon_days:
        predicted_date = (last_recorded_date + pd.Timedelta(days=horizon_days[horizon])).strftime("%Y-%m-%d")

    # === Final result JSON ===
    result = {
        "horizon": horizon_label,
        "last_recorded_date": last_recorded_date.strftime("%Y-%m-%d") if last_recorded_date is not None else None,
        "predicted_date": predicted_date,
        "last_recorded_close": round(last_close, 2),
        "predicted_return_mean": round(pred_mean * 100, 3),
        "predicted_return_std": round(pred_std * 100, 3),
        "predicted_next_close": round(predicted_close, 2),
        "predicted_close_minimum": round(predicted_min_close, 2),
        "predicted_close_maximum": round(predicted_max_close, 2),
        "confidence": round(confidence, 4)
    }
    return json.dumps(result, indent=2, ensure_ascii=False)
# ==========================================================
# Gradio UI
# ==========================================================
# Build the Gradio UI: a horizon selector, summary Markdown fields, and a
# raw-JSON output box, then launch the app.
with gr.Blocks(title="ETH Return Forecast") as demo:
    gr.Markdown("# Ethereum Return Forecasting Model")
    gr.Markdown("Select a forecast horizon and get predictions as JSON output:")

    # Dynamic text elements, updated on each prediction
    last_date_display = gr.Markdown("**Last Recorded Date:** —")
    predicted_date_display = gr.Markdown("**Predicted Date:** —")
    predicted_return_display = gr.Markdown("**Predicted Return:** —")
    predicted_close_range_display = gr.Markdown("**Predicted Close Range:** —")

    # Dropdown for user to select forecast horizon
    horizon_dropdown = gr.Dropdown(
        choices=list(HORIZONS.keys()),
        value="Daily",
        label="Forecast Horizon",
    )

    # Output JSON box
    output_box = gr.Code(label="JSON Output", language="json")
    predict_btn = gr.Button("Predict")

    # === Define a wrapper that returns multiple outputs ===
    def run_prediction(horizon_label):
        """Call predict_eth and fan the result out to the display widgets."""
        result_json = predict_eth(horizon_label)
        try:
            result = json.loads(result_json)
            if "error" in result:
                # Error payloads carry no forecast fields; previously the
                # 0-defaults were formatted as "0.00% ± 0.00%", which looked
                # like a real prediction. Show N/A everywhere instead.
                raise ValueError(result["error"])
            last_date = result.get("last_recorded_date", "N/A")
            pred_date = result.get("predicted_date", "N/A")
            predicted_return = f"{result.get('predicted_return_mean', 0):.2f}% ± {result.get('predicted_return_std', 0):.2f}%"
            predicted_close_range = f"${result.get('predicted_close_minimum', 0):,.2f} - ${result.get('predicted_close_maximum', 0):,.2f}"
        except Exception:
            last_date, pred_date = "N/A", "N/A"
            predicted_return, predicted_close_range = "N/A", "N/A"

        return (
            f"**Last Recorded Date:** {last_date}",
            f"**Predicted Date:** {pred_date}",
            f"**Predicted Return:** {predicted_return}",
            f"**Predicted Close Range:** {predicted_close_range}",
            result_json
        )

    # Connect button click
    predict_btn.click(
        fn=run_prediction,
        inputs=horizon_dropdown,
        outputs=[last_date_display, predicted_date_display, predicted_return_display, predicted_close_range_display, output_box],
    )

demo.launch()