pyroleli commited on
Commit
431a4a5
·
verified ·
1 Parent(s): be937ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -52
app.py CHANGED
@@ -6,91 +6,104 @@ import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
  import timesfm
8
 
9
- # --- UPDATED INITIALIZATION (v2.5 API) ---
10
- # We use the new torch-specific class and ForecastConfig object
11
  model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
12
  "google/timesfm-2.5-200m-pytorch"
13
  )
14
 
15
- # Configure the model using the new ForecastConfig
16
  model.compile(timesfm.ForecastConfig(
17
- max_context=512,
18
- max_horizon=128,
19
  normalize_inputs=True,
20
  infer_is_positive=True
21
  ))
22
 
23
- def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
24
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
25
- vertical_spacing=0.05, row_heights=[0.75, 0.25])
26
 
27
- display_df = df[:-30] if is_backtest else df
28
 
29
- # 1. Historical Price
30
- fig.add_trace(go.Scatter(x=display_df.index, y=display_df['Close'], name='History', line=dict(color='#2962FF')), row=1, col=1)
 
31
 
32
- # 2. Backtest Truth
33
  if is_backtest:
34
- fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual', line=dict(color='#787b86', dash='dot')), row=1, col=1)
 
35
 
36
- # 3. AI Forecast
37
- fc_dates = pd.date_range(start=display_df.index[-1], periods=len(forecast_df)+1, freq='D')[1:]
38
- fig.add_trace(go.Scatter(x=fc_dates, y=forecast_df, name='AI Forecast', line=dict(color='#F23645', width=3)), row=1, col=1)
 
39
 
40
- # 4. Volume
41
- fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume', marker_color='rgba(38, 166, 154, 0.5)'), row=2, col=1)
 
42
 
43
- fig.update_layout(template='plotly_dark', paper_bgcolor='#131722', plot_bgcolor='#131722', margin=dict(l=10, r=10, t=40, b=10))
 
 
 
44
  return fig
45
 
46
- def run_analysis(ticker, horizon, mode):
47
  try:
48
- df = yf.download(ticker, period="2y")
49
- if df.empty: return None, "⚠️ Ticker Not Found", None
 
50
 
51
- train_df = df[:-30] if mode == "Backtest (Reality Check)" else df
52
 
53
- # New API uses simple array inputs
54
- inputs = [train_df['Close'].values]
55
-
56
- # Run inference using the v2.5 forecast method
57
  point_forecast, _ = model.forecast(
58
- inputs=inputs,
59
- horizon=horizon if mode != "Backtest (Reality Check)" else 30
60
  )
 
 
 
 
61
 
62
- forecast_values = point_forecast[0] # Get results for the first (only) input
63
-
64
- # Save CSV
65
- csv_path = f"{ticker}_forecast.csv"
66
- pd.DataFrame({'Forecast': forecast_values}).to_csv(csv_path, index=False)
67
-
68
- # Signals
69
- if mode != "Backtest (Reality Check)":
70
- pct = ((forecast_values[-1] - train_df['Close'].iloc[-1]) / train_df['Close'].iloc[-1]) * 100
71
- signal = f"<h3 style='color: {'#00ff88' if pct > 0 else '#ff4444'};'>{ 'BULLISH' if pct > 0 else 'BEARISH' } ({pct:+.2f}%)</h3>"
72
  else:
73
- acc = 100 - (abs(df['Close'].iloc[-1] - forecast_values[-1]) / df['Close'].iloc[-1] * 100)
74
- signal = f"<h3 style='color: #FFD700;'>Accuracy: {acc:.1f}%</h3>"
 
75
 
76
- return get_financial_plot(df, forecast_values, ticker, (mode == "Backtest (Reality Check)")), signal, csv_path
 
 
77
 
 
 
 
78
  except Exception as e:
79
- return None, f"Runtime Error: {str(e)}", None
80
 
81
- # --- UI ---
82
- with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #000000}") as demo:
83
- gr.HTML("<h2 style='color: #2962FF; text-align: center;'>G-TIMES 2.5 QUANT TERMINAL</h2>")
 
 
84
  with gr.Row():
85
  with gr.Column(scale=1):
86
- ticker_in = gr.Textbox(label="SYMBOL", value="NVDA")
87
- mode_in = gr.Radio(["Future Forecast", "Backtest (Reality Check)"], label="MODE", value="Future Forecast")
88
- days_in = gr.Slider(5, 128, value=30, label="Days")
89
- btn = gr.Button("RUN ANALYSIS", variant="primary")
90
- result_box = gr.HTML()
91
- file_output = gr.File(label="Download Forecast")
 
92
  with gr.Column(scale=4):
93
  plot_out = gr.Plot()
94
- btn.click(run_analysis, [ticker_in, days_in, mode_in], [plot_out, result_box, file_output])
 
95
 
96
  demo.launch()
 
6
  from plotly.subplots import make_subplots
7
  import timesfm
8
 
9
+ # --- SDK v2.5 INITIALIZATION ---
10
+ # Load the 200M parameter model specifically for PyTorch
11
  model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
12
  "google/timesfm-2.5-200m-pytorch"
13
  )
14
 
15
+ # FIXED: Parameters are now moved into ForecastConfig
16
  model.compile(timesfm.ForecastConfig(
17
+ max_context=1024, # v2.5 now supports up to 16k
18
+ max_horizon=256,
19
  normalize_inputs=True,
20
  infer_is_positive=True
21
  ))
22
 
23
+ def generate_professional_chart(df, forecast_data, ticker, is_backtest=False):
24
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
25
+ vertical_spacing=0.03, row_heights=[0.8, 0.2])
26
 
27
+ hist_df = df[:-30] if is_backtest else df
28
 
29
+ # Historical Trend
30
+ fig.add_trace(go.Scatter(x=hist_df.index, y=hist_df['Close'], name='Market Price',
31
+ line=dict(color='#2962FF', width=2)), row=1, col=1)
32
 
33
+ # Reality Check (Dotted)
34
  if is_backtest:
35
+ fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual Path',
36
+ line=dict(color='#787b86', dash='dot')), row=1, col=1)
37
 
38
+ # AI Forecast Line
39
+ fc_dates = pd.date_range(start=hist_df.index[-1], periods=len(forecast_data)+1, freq='B')[1:]
40
+ fig.add_trace(go.Scatter(x=fc_dates, y=forecast_data, name='AI Forecast',
41
+ line=dict(color='#F23645', width=3)), row=1, col=1)
42
 
43
+ # Volume Bars
44
+ fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume',
45
+ marker_color='rgba(120, 123, 134, 0.3)'), row=2, col=1)
46
 
47
+ fig.update_layout(template='plotly_dark', paper_bgcolor='#131722', plot_bgcolor='#131722',
48
+ margin=dict(l=50, r=50, t=30, b=30), hovermode="x unified")
49
+ fig.update_yaxes(side='right', gridcolor='#2a2e39')
50
+ fig.update_xaxes(gridcolor='#2a2e39')
51
  return fig
52
 
53
+ def run_terminal(ticker, horizon, mode):
54
  try:
55
+ # 1. Data Ingestion
56
+ data = yf.download(ticker, period="2y")
57
+ if data.empty: return None, "⚠️ SYMBOL NOT FOUND", None
58
 
59
+ train_data = data[:-30] if mode == "Backtest Mode" else data
60
 
61
+ # 2. SDK Inference (Array-based input for v2.5)
62
+ inputs = [train_data['Close'].values]
 
 
63
  point_forecast, _ = model.forecast(
64
+ inputs=inputs,
65
+ horizon=30 if mode == "Backtest Mode" else horizon
66
  )
67
+ prediction = point_forecast[0]
68
+
69
+ # 3. Financial Metrics
70
+ pct_change = ((prediction[-1] - train_data['Close'].iloc[-1]) / train_data['Close'].iloc[-1]) * 100
71
 
72
+ if mode == "Future Forecast":
73
+ signal = f"<h2 style='color: {'#00ff88' if pct_change > 0 else '#ff4444'}; text-align: center;'>" \
74
+ f"{'BULLISH' if pct_change > 0 else 'BEARISH'} ({pct_change:+.2f}%)</h2>"
 
 
 
 
 
 
 
75
  else:
76
+ # Calculate Backtest Accuracy
77
+ accuracy = 100 - abs((data['Close'].iloc[-1] - prediction[-1]) / data['Close'].iloc[-1] * 100)
78
+ signal = f"<h2 style='color: #FFD700; text-align: center;'>AI ACCURACY: {accuracy:.1f}%</h2>"
79
 
80
+ # 4. Generate Report
81
+ report_path = f"{ticker}_AI_Report.csv"
82
+ pd.DataFrame({'Forecast_Price': prediction}).to_csv(report_path, index=False)
83
 
84
+ chart = generate_professional_chart(data, prediction, ticker, (mode == "Backtest Mode"))
85
+ return chart, signal, report_path
86
+
87
  except Exception as e:
88
+ return None, f"<div style='color:red;'>API Error: {str(e)}</div>", None
89
 
90
+ # --- TERMINAL UI ---
91
+ with gr.Blocks(title="G-TIMES QUANT 2.5", theme=gr.themes.Base()) as demo:
92
+ gr.HTML("<div style='background-color:#131722; padding:20px; border-bottom:3px solid #2962FF; text-align:center;'>"
93
+ "<h1 style='color:white; letter-spacing:3px;'>G-TIMES <span style='color:#2962FF;'>QUANT 2.5</span></h1></div>")
94
+
95
  with gr.Row():
96
  with gr.Column(scale=1):
97
+ ticker_in = gr.Textbox(label="TICKER", value="TSLA")
98
+ mode_in = gr.Radio(["Future Forecast", "Backtest Mode"], label="STRATEGY", value="Future Forecast")
99
+ days_in = gr.Slider(7, 128, value=30, label="HORIZON")
100
+ btn = gr.Button("RUN QUANT ANALYSIS", variant="primary")
101
+ status_out = gr.HTML()
102
+ file_out = gr.File(label="CSV EXPORT")
103
+
104
  with gr.Column(scale=4):
105
  plot_out = gr.Plot()
106
+
107
+ btn.click(run_terminal, [ticker_in, days_in, mode_in], [plot_out, status_out, file_out])
108
 
109
  demo.launch()