pyroleli commited on
Commit
0a114f9
·
verified ·
1 Parent(s): 5e546fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -67
app.py CHANGED
@@ -5,121 +5,165 @@ import plotly.graph_objects as go
5
  from plotly.subplots import make_subplots
6
  import numpy as np
7
  import yfinance as yf
8
- from timesfm import TimesFm
9
 
10
- # Initialize Google TimesFM with updated API arguments
11
- # Note: 'context_len' changed to 'context_length', etc.
12
  tfm = TimesFm(
13
- context_length=512,
14
- horizon_length=128,
15
- input_patch_len=32,
16
- output_patch_len=128,
17
- num_layers=20,
18
- model_dims=1280,
19
- backend="cpu",
 
 
 
 
20
  )
21
 
22
- tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
23
-
24
- def get_tradingview_plot(df, forecast_df, ticker):
25
- # Create subplots: Chart on top (80%), Volume on bottom (20%)
26
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
27
- vertical_spacing=0.03, subplot_titles=(f'{ticker} Analysis', ''),
28
- row_width=[0.2, 0.8])
29
 
30
- # 1. Main Price Line
 
 
 
31
  fig.add_trace(go.Scatter(
32
- x=df.index, y=df['Close'],
33
- name='Historical Close',
34
  line=dict(color='#2962FF', width=2)
35
  ), row=1, col=1)
36
 
37
- # 2. TimesFM Forecast (Dotted Projection)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  fig.add_trace(go.Scatter(
39
- x=forecast_df['ds'], y=forecast_df['timesfm'],
40
- name='AI Forecast',
41
- line=dict(color='#F23645', width=3, dash='dot')
42
  ), row=1, col=1)
43
 
44
- # 3. Volume Profile (Bars)
45
  fig.add_trace(go.Bar(
46
  x=df.index, y=df['Volume'],
47
  name='Volume',
48
- marker_color='#26a69a',
49
- opacity=0.5
50
  ), row=2, col=1)
51
 
52
- # Professional Dark Theme Styling
53
  fig.update_layout(
54
  template='plotly_dark',
55
- hovermode="x unified",
56
  paper_bgcolor='#131722',
57
  plot_bgcolor='#131722',
58
- showlegend=True,
59
- margin=dict(l=10, r=10, t=50, b=10),
60
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
61
  )
62
-
63
- fig.update_xaxes(showgrid=True, gridcolor='#2a2e39', rangeslider_visible=False)
64
- fig.update_yaxes(showgrid=True, gridcolor='#2a2e39', side='right')
65
 
66
  return fig
67
 
68
- def run_analysis(ticker, horizon):
69
  try:
70
- df = yf.download(ticker, period="1y")
71
- if df.empty: return None, "⚠️ Invalid Ticker"
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Format for TimesFM
74
  input_df = pd.DataFrame({
75
  'unique_id': [ticker],
76
- 'ds': df.index,
77
- 'y': df['Close'].values
78
  })
79
 
80
- # Execute Forecast
81
  forecast_df, _ = tfm.forecast_on_df(
82
  inputs=input_df,
83
  freq="D",
84
- value_name="y"
 
85
  )
86
-
87
  forecast_df = forecast_df.head(horizon)
88
 
89
- # Calculate Signals
90
- current_price = df['Close'].iloc[-1]
91
- final_pred = forecast_df['timesfm'].iloc[-1]
92
- change_pct = ((final_pred - current_price) / current_price) * 100
93
-
94
- signal = "STRONG BUY" if change_pct > 5 else "BUY" if change_pct > 0 else "SELL"
95
- signal_color = "#00ff88" if "BUY" in signal else "#ff4444"
96
-
97
- fig = get_tradingview_plot(df, forecast_df, ticker)
 
 
 
 
 
 
 
 
 
98
 
99
- html_signal = f"<h2 style='color: {signal_color}; text-align: center;'>{signal} ({change_pct:+.2f}%)</h2>"
100
 
101
- return fig, html_signal
102
  except Exception as e:
103
- return None, f"Error: {str(e)}"
104
 
105
- # Professional Layout
106
- with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #131722; color: white;}") as demo:
107
- gr.HTML("<h1 style='text-align: center; color: #2962FF; font-family: sans-serif;'>G-TIMES QUANT TERMINAL</h1>")
 
 
 
 
 
 
108
 
109
  with gr.Row():
110
- with gr.Column(scale=1):
111
- with gr.Box():
112
- ticker_input = gr.Textbox(label="Ticker Symbol", value="NVDA")
113
- horizon_slider = gr.Slider(5, 128, value=30, label="Forecast Horizon (Days)")
114
- analyze_btn = gr.Button("RUN QUANT ANALYSIS", variant="primary")
 
 
 
115
 
116
- output_signal = gr.HTML(label="Market Signal")
117
- gr.Markdown("---")
118
- gr.Markdown("### Terminal Info\n- **Model**: Google TimesFM-1.0\n- **Backend**: CPU PyTorch\n- **Freq**: Daily Close")
119
 
120
  with gr.Column(scale=4):
121
- plot_output = gr.Plot()
122
 
123
- analyze_btn.click(run_analysis, [ticker_input, horizon_slider], [plot_output, output_signal])
124
 
125
  demo.launch()
 
5
  from plotly.subplots import make_subplots
6
  import numpy as np
7
  import yfinance as yf
8
+ from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint
9
 
10
+ # --- 1. FIXED INITIALIZATION ---
11
+ # The new API requires 'hparams' and 'checkpoint' objects, not flat arguments.
12
  tfm = TimesFm(
13
+ hparams=TimesFmHparams(
14
+ backend="cpu",
15
+ per_core_batch_size=32,
16
+ horizon_len=128,
17
+ context_len=512,
18
+ num_layers=20,
19
+ model_dims=1280,
20
+ ),
21
+ checkpoint=TimesFmCheckpoint(
22
+ huggingface_repo_id="google/timesfm-1.0-200m"
23
+ ),
24
  )
25
 
26
+ def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
27
+ # Create Professional 2-Row Chart (Price & Volume)
 
 
28
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
29
+ vertical_spacing=0.05, row_heights=[0.75, 0.25])
 
30
 
31
+ # A. Historical Data (Blue)
32
+ # If backtesting, we only show 'training' data in blue to prove the AI didn't see the rest
33
+ display_df = df[:-30] if is_backtest else df
34
+
35
  fig.add_trace(go.Scatter(
36
+ x=display_df.index, y=display_df['Close'],
37
+ name='Historical Price',
38
  line=dict(color='#2962FF', width=2)
39
  ), row=1, col=1)
40
 
41
+ # B. The "Truth" (Grey Line - Only for Backtesting)
42
+ if is_backtest:
43
+ truth_df = df[-30:]
44
+ fig.add_trace(go.Scatter(
45
+ x=truth_df.index, y=truth_df['Close'],
46
+ name='Actual Market Move',
47
+ line=dict(color='#787b86', width=2, dash='dot')
48
+ ), row=1, col=1)
49
+
50
+ # C. AI Forecast (Red)
51
+ # Connect the forecast line to the last historical point for a seamless look
52
+ last_hist_date = display_df.index[-1]
53
+ last_hist_val = display_df['Close'].iloc[-1]
54
+
55
+ # Prepend the last historical point to the forecast data
56
+ fc_dates = [last_hist_date] + list(forecast_df['ds'])
57
+ fc_vals = [last_hist_val] + list(forecast_df['timesfm'])
58
+
59
  fig.add_trace(go.Scatter(
60
+ x=fc_dates, y=fc_vals,
61
+ name='AI Prediction',
62
+ line=dict(color='#F23645', width=3)
63
  ), row=1, col=1)
64
 
65
+ # D. Volume (Bars)
66
  fig.add_trace(go.Bar(
67
  x=df.index, y=df['Volume'],
68
  name='Volume',
69
+ marker_color='rgba(38, 166, 154, 0.5)'
 
70
  ), row=2, col=1)
71
 
72
+ # TradingView-Style Dark Theme
73
  fig.update_layout(
74
  template='plotly_dark',
 
75
  paper_bgcolor='#131722',
76
  plot_bgcolor='#131722',
77
+ margin=dict(l=10, r=10, t=40, b=10),
78
+ legend=dict(orientation="h", y=1.02, x=0),
79
+ hovermode="x unified"
80
  )
81
+ fig.update_yaxes(gridcolor='#2a2e39', side='right')
82
+ fig.update_xaxes(gridcolor='#2a2e39')
 
83
 
84
  return fig
85
 
86
+ def run_analysis(ticker, horizon, mode):
87
  try:
88
+ # Fetch Data
89
+ df = yf.download(ticker, period="2y") # Get more data for stability
90
+ if df.empty: return None, "⚠️ Ticker Not Found"
91
+
92
+ # --- MODE LOGIC ---
93
+ if mode == "Backtest (Reality Check)":
94
+ # Hide the last 30 days from the AI
95
+ train_df = df[:-30]
96
+ horizon = 30 # Fixed horizon for backtest comparison
97
+ else:
98
+ # Use full data for real future prediction
99
+ train_df = df
100
 
101
  # Format for TimesFM
102
  input_df = pd.DataFrame({
103
  'unique_id': [ticker],
104
+ 'ds': train_df.index,
105
+ 'y': train_df['Close'].values
106
  })
107
 
108
+ # Run Forecast
109
  forecast_df, _ = tfm.forecast_on_df(
110
  inputs=input_df,
111
  freq="D",
112
+ value_name="y",
113
+ forecast_context_len=512 # Explicitly use context
114
  )
 
115
  forecast_df = forecast_df.head(horizon)
116
 
117
+ # Generate Signal (Only for Future Mode)
118
+ if mode != "Backtest (Reality Check)":
119
+ start_price = train_df['Close'].iloc[-1]
120
+ end_price = forecast_df['timesfm'].iloc[-1]
121
+ pct_change = ((end_price - start_price) / start_price) * 100
122
+
123
+ color = "#00ff88" if pct_change > 0 else "#ff4444"
124
+ direction = "BULLISH" if pct_change > 0 else "BEARISH"
125
+ signal_html = f"<h3 style='color: {color}; margin: 0;'>{direction} ({pct_change:+.2f}%)</h3>"
126
+ else:
127
+ # Calculate Accuracy Score for Backtest
128
+ real_end = df['Close'].iloc[-1]
129
+ pred_end = forecast_df['timesfm'].iloc[-1]
130
+ error = abs(real_end - pred_end) / real_end * 100
131
+ accuracy = 100 - error
132
+ signal_html = f"<h3 style='color: #FFD700; margin: 0;'>AI Accuracy: {accuracy:.1f}%</h3>"
133
+
134
+ fig = get_financial_plot(df, forecast_df, ticker, is_backtest=(mode == "Backtest (Reality Check)"))
135
 
136
+ return fig, signal_html
137
 
 
138
  except Exception as e:
139
+ return None, f"<span style='color:red'>Error: {str(e)}</span>"
140
 
141
+ # UI Layout
142
+ with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #000000}") as demo:
143
+ gr.HTML("""
144
+ <div style='background-color: #131722; padding: 15px; border-bottom: 2px solid #2962FF;'>
145
+ <h2 style='color: white; margin:0; text-align: center; font-family: sans-serif;'>
146
+ G-TIMES <span style='color: #2962FF;'>PRO TERMINAL</span>
147
+ </h2>
148
+ </div>
149
+ """)
150
 
151
  with gr.Row():
152
+ with gr.Column(scale=1, min_width=300):
153
+ with gr.Group():
154
+ ticker_in = gr.Textbox(label="SYMBOL", value="BTC-USD")
155
+ mode_in = gr.Radio(["Future Forecast", "Backtest (Reality Check)"],
156
+ label="ANALYSIS MODE", value="Future Forecast")
157
+ days_in = gr.Slider(5, 128, value=14, label="Horizon (Days)")
158
+
159
+ btn = gr.Button("EXECUTE STRATEGY", variant="primary")
160
 
161
+ gr.HTML("<br>")
162
+ result_box = gr.HTML(label="Signal")
 
163
 
164
  with gr.Column(scale=4):
165
+ plot_out = gr.Plot(label="Technical Chart")
166
 
167
+ btn.click(run_analysis, [ticker_in, days_in, mode_in], [plot_out, result_box])
168
 
169
  demo.launch()