Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,91 +6,104 @@ import plotly.graph_objects as go
|
|
| 6 |
from plotly.subplots import make_subplots
|
| 7 |
import timesfm
|
| 8 |
|
| 9 |
-
# ---
|
| 10 |
-
#
|
| 11 |
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
|
| 12 |
"google/timesfm-2.5-200m-pytorch"
|
| 13 |
)
|
| 14 |
|
| 15 |
-
#
|
| 16 |
model.compile(timesfm.ForecastConfig(
|
| 17 |
-
max_context=
|
| 18 |
-
max_horizon=
|
| 19 |
normalize_inputs=True,
|
| 20 |
infer_is_positive=True
|
| 21 |
))
|
| 22 |
|
| 23 |
-
def
|
| 24 |
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
|
| 25 |
-
vertical_spacing=0.
|
| 26 |
|
| 27 |
-
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
fig.add_trace(go.Scatter(x=
|
|
|
|
| 31 |
|
| 32 |
-
#
|
| 33 |
if is_backtest:
|
| 34 |
-
fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual
|
|
|
|
| 35 |
|
| 36 |
-
#
|
| 37 |
-
fc_dates = pd.date_range(start=
|
| 38 |
-
fig.add_trace(go.Scatter(x=fc_dates, y=
|
|
|
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume',
|
|
|
|
| 42 |
|
| 43 |
-
fig.update_layout(template='plotly_dark', paper_bgcolor='#131722', plot_bgcolor='#131722',
|
|
|
|
|
|
|
|
|
|
| 44 |
return fig
|
| 45 |
|
| 46 |
-
def
|
| 47 |
try:
|
| 48 |
-
|
| 49 |
-
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
|
| 53 |
-
#
|
| 54 |
-
inputs = [
|
| 55 |
-
|
| 56 |
-
# Run inference using the v2.5 forecast method
|
| 57 |
point_forecast, _ = model.forecast(
|
| 58 |
-
inputs=inputs,
|
| 59 |
-
horizon=
|
| 60 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
csv_path = f"{ticker}_forecast.csv"
|
| 66 |
-
pd.DataFrame({'Forecast': forecast_values}).to_csv(csv_path, index=False)
|
| 67 |
-
|
| 68 |
-
# Signals
|
| 69 |
-
if mode != "Backtest (Reality Check)":
|
| 70 |
-
pct = ((forecast_values[-1] - train_df['Close'].iloc[-1]) / train_df['Close'].iloc[-1]) * 100
|
| 71 |
-
signal = f"<h3 style='color: {'#00ff88' if pct > 0 else '#ff4444'};'>{ 'BULLISH' if pct > 0 else 'BEARISH' } ({pct:+.2f}%)</h3>"
|
| 72 |
else:
|
| 73 |
-
|
| 74 |
-
|
|
|
|
| 75 |
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
except Exception as e:
|
| 79 |
-
return None, f"
|
| 80 |
|
| 81 |
-
# --- UI ---
|
| 82 |
-
with gr.Blocks(theme=gr.themes.
|
| 83 |
-
gr.HTML("<
|
|
|
|
|
|
|
| 84 |
with gr.Row():
|
| 85 |
with gr.Column(scale=1):
|
| 86 |
-
ticker_in = gr.Textbox(label="
|
| 87 |
-
mode_in = gr.Radio(["Future Forecast", "Backtest
|
| 88 |
-
days_in = gr.Slider(
|
| 89 |
-
btn = gr.Button("RUN ANALYSIS", variant="primary")
|
| 90 |
-
|
| 91 |
-
|
|
|
|
| 92 |
with gr.Column(scale=4):
|
| 93 |
plot_out = gr.Plot()
|
| 94 |
-
|
|
|
|
| 95 |
|
| 96 |
demo.launch()
|
|
|
|
| 6 |
from plotly.subplots import make_subplots
|
| 7 |
import timesfm
|
| 8 |
|
| 9 |
+
# --- SDK v2.5 INITIALIZATION ---
|
| 10 |
+
# Load the 200M parameter model specifically for PyTorch
|
| 11 |
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
|
| 12 |
"google/timesfm-2.5-200m-pytorch"
|
| 13 |
)
|
| 14 |
|
| 15 |
+
# FIXED: Parameters are now moved into ForecastConfig
|
| 16 |
model.compile(timesfm.ForecastConfig(
|
| 17 |
+
max_context=1024, # v2.5 now supports up to 16k
|
| 18 |
+
max_horizon=256,
|
| 19 |
normalize_inputs=True,
|
| 20 |
infer_is_positive=True
|
| 21 |
))
|
| 22 |
|
| 23 |
+
def generate_professional_chart(df, forecast_data, ticker, is_backtest=False):
|
| 24 |
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
|
| 25 |
+
vertical_spacing=0.03, row_heights=[0.8, 0.2])
|
| 26 |
|
| 27 |
+
hist_df = df[:-30] if is_backtest else df
|
| 28 |
|
| 29 |
+
# Historical Trend
|
| 30 |
+
fig.add_trace(go.Scatter(x=hist_df.index, y=hist_df['Close'], name='Market Price',
|
| 31 |
+
line=dict(color='#2962FF', width=2)), row=1, col=1)
|
| 32 |
|
| 33 |
+
# Reality Check (Dotted)
|
| 34 |
if is_backtest:
|
| 35 |
+
fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual Path',
|
| 36 |
+
line=dict(color='#787b86', dash='dot')), row=1, col=1)
|
| 37 |
|
| 38 |
+
# AI Forecast Line
|
| 39 |
+
fc_dates = pd.date_range(start=hist_df.index[-1], periods=len(forecast_data)+1, freq='B')[1:]
|
| 40 |
+
fig.add_trace(go.Scatter(x=fc_dates, y=forecast_data, name='AI Forecast',
|
| 41 |
+
line=dict(color='#F23645', width=3)), row=1, col=1)
|
| 42 |
|
| 43 |
+
# Volume Bars
|
| 44 |
+
fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume',
|
| 45 |
+
marker_color='rgba(120, 123, 134, 0.3)'), row=2, col=1)
|
| 46 |
|
| 47 |
+
fig.update_layout(template='plotly_dark', paper_bgcolor='#131722', plot_bgcolor='#131722',
|
| 48 |
+
margin=dict(l=50, r=50, t=30, b=30), hovermode="x unified")
|
| 49 |
+
fig.update_yaxes(side='right', gridcolor='#2a2e39')
|
| 50 |
+
fig.update_xaxes(gridcolor='#2a2e39')
|
| 51 |
return fig
|
| 52 |
|
| 53 |
+
def run_terminal(ticker, horizon, mode):
|
| 54 |
try:
|
| 55 |
+
# 1. Data Ingestion
|
| 56 |
+
data = yf.download(ticker, period="2y")
|
| 57 |
+
if data.empty: return None, "⚠️ SYMBOL NOT FOUND", None
|
| 58 |
|
| 59 |
+
train_data = data[:-30] if mode == "Backtest Mode" else data
|
| 60 |
|
| 61 |
+
# 2. SDK Inference (Array-based input for v2.5)
|
| 62 |
+
inputs = [train_data['Close'].values]
|
|
|
|
|
|
|
| 63 |
point_forecast, _ = model.forecast(
|
| 64 |
+
inputs=inputs,
|
| 65 |
+
horizon=30 if mode == "Backtest Mode" else horizon
|
| 66 |
)
|
| 67 |
+
prediction = point_forecast[0]
|
| 68 |
+
|
| 69 |
+
# 3. Financial Metrics
|
| 70 |
+
pct_change = ((prediction[-1] - train_data['Close'].iloc[-1]) / train_data['Close'].iloc[-1]) * 100
|
| 71 |
|
| 72 |
+
if mode == "Future Forecast":
|
| 73 |
+
signal = f"<h2 style='color: {'#00ff88' if pct_change > 0 else '#ff4444'}; text-align: center;'>" \
|
| 74 |
+
f"{'BULLISH' if pct_change > 0 else 'BEARISH'} ({pct_change:+.2f}%)</h2>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
else:
|
| 76 |
+
# Calculate Backtest Accuracy
|
| 77 |
+
accuracy = 100 - abs((data['Close'].iloc[-1] - prediction[-1]) / data['Close'].iloc[-1] * 100)
|
| 78 |
+
signal = f"<h2 style='color: #FFD700; text-align: center;'>AI ACCURACY: {accuracy:.1f}%</h2>"
|
| 79 |
|
| 80 |
+
# 4. Generate Report
|
| 81 |
+
report_path = f"{ticker}_AI_Report.csv"
|
| 82 |
+
pd.DataFrame({'Forecast_Price': prediction}).to_csv(report_path, index=False)
|
| 83 |
|
| 84 |
+
chart = generate_professional_chart(data, prediction, ticker, (mode == "Backtest Mode"))
|
| 85 |
+
return chart, signal, report_path
|
| 86 |
+
|
| 87 |
except Exception as e:
|
| 88 |
+
return None, f"<div style='color:red;'>API Error: {str(e)}</div>", None
|
| 89 |
|
| 90 |
+
# --- TERMINAL UI ---
|
| 91 |
+
with gr.Blocks(title="G-TIMES QUANT 2.5", theme=gr.themes.Base()) as demo:
|
| 92 |
+
gr.HTML("<div style='background-color:#131722; padding:20px; border-bottom:3px solid #2962FF; text-align:center;'>"
|
| 93 |
+
"<h1 style='color:white; letter-spacing:3px;'>G-TIMES <span style='color:#2962FF;'>QUANT 2.5</span></h1></div>")
|
| 94 |
+
|
| 95 |
with gr.Row():
|
| 96 |
with gr.Column(scale=1):
|
| 97 |
+
ticker_in = gr.Textbox(label="TICKER", value="TSLA")
|
| 98 |
+
mode_in = gr.Radio(["Future Forecast", "Backtest Mode"], label="STRATEGY", value="Future Forecast")
|
| 99 |
+
days_in = gr.Slider(7, 128, value=30, label="HORIZON")
|
| 100 |
+
btn = gr.Button("RUN QUANT ANALYSIS", variant="primary")
|
| 101 |
+
status_out = gr.HTML()
|
| 102 |
+
file_out = gr.File(label="CSV EXPORT")
|
| 103 |
+
|
| 104 |
with gr.Column(scale=4):
|
| 105 |
plot_out = gr.Plot()
|
| 106 |
+
|
| 107 |
+
btn.click(run_terminal, [ticker_in, days_in, mode_in], [plot_out, status_out, file_out])
|
| 108 |
|
| 109 |
demo.launch()
|