pyroleli commited on
Commit
87caa42
·
verified ·
1 Parent(s): da843f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -109
app.py CHANGED
@@ -1,17 +1,16 @@
1
  import gradio as gr
2
  import pandas as pd
3
- import torch
4
- import plotly.graph_objects as go
5
- from plotly.subplots import make_subplots
6
  import numpy as np
7
  import yfinance as yf
 
 
8
  from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint
9
 
10
- # --- 1. FIXED INITIALIZATION ---
11
- # The new API requires 'hparams' and 'checkpoint' objects, not flat arguments.
12
  tfm = TimesFm(
13
  hparams=TimesFmHparams(
14
- backend="cpu",
15
  per_core_batch_size=32,
16
  horizon_len=128,
17
  context_len=512,
@@ -19,150 +18,82 @@ tfm = TimesFm(
19
  model_dims=1280,
20
  ),
21
  checkpoint=TimesFmCheckpoint(
22
- huggingface_repo_id="google/timesfm-1.0-200m"
23
  ),
24
  )
25
 
26
  def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
27
- # Create Professional 2-Row Chart (Price & Volume)
28
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
29
  vertical_spacing=0.05, row_heights=[0.75, 0.25])
30
 
31
- # A. Historical Data (Blue)
32
- # If backtesting, we only show 'training' data in blue to prove the AI didn't see the rest
33
  display_df = df[:-30] if is_backtest else df
34
 
35
- fig.add_trace(go.Scatter(
36
- x=display_df.index, y=display_df['Close'],
37
- name='Historical Price',
38
- line=dict(color='#2962FF', width=2)
39
- ), row=1, col=1)
40
 
41
- # B. The "Truth" (Grey Line - Only for Backtesting)
42
  if is_backtest:
43
- truth_df = df[-30:]
44
- fig.add_trace(go.Scatter(
45
- x=truth_df.index, y=truth_df['Close'],
46
- name='Actual Market Move',
47
- line=dict(color='#787b86', width=2, dash='dot')
48
- ), row=1, col=1)
49
 
50
- # C. AI Forecast (Red)
51
- # Connect the forecast line to the last historical point for a seamless look
52
- last_hist_date = display_df.index[-1]
53
- last_hist_val = display_df['Close'].iloc[-1]
54
-
55
- # Prepend the last historical point to the forecast data
56
- fc_dates = [last_hist_date] + list(forecast_df['ds'])
57
- fc_vals = [last_hist_val] + list(forecast_df['timesfm'])
58
-
59
- fig.add_trace(go.Scatter(
60
- x=fc_dates, y=fc_vals,
61
- name='AI Prediction',
62
- line=dict(color='#F23645', width=3)
63
- ), row=1, col=1)
64
 
65
- # D. Volume (Bars)
66
- fig.add_trace(go.Bar(
67
- x=df.index, y=df['Volume'],
68
- name='Volume',
69
- marker_color='rgba(38, 166, 154, 0.5)'
70
- ), row=2, col=1)
71
 
72
- # TradingView-Style Dark Theme
73
- fig.update_layout(
74
- template='plotly_dark',
75
- paper_bgcolor='#131722',
76
- plot_bgcolor='#131722',
77
- margin=dict(l=10, r=10, t=40, b=10),
78
- legend=dict(orientation="h", y=1.02, x=0),
79
- hovermode="x unified"
80
- )
81
- fig.update_yaxes(gridcolor='#2a2e39', side='right')
82
- fig.update_xaxes(gridcolor='#2a2e39')
83
-
84
  return fig
85
 
86
  def run_analysis(ticker, horizon, mode):
87
  try:
88
- # Fetch Data
89
- df = yf.download(ticker, period="2y") # Get more data for stability
90
  if df.empty: return None, "⚠️ Ticker Not Found"
91
 
92
- # --- MODE LOGIC ---
93
- if mode == "Backtest (Reality Check)":
94
- # Hide the last 30 days from the AI
95
- train_df = df[:-30]
96
- horizon = 30 # Fixed horizon for backtest comparison
97
- else:
98
- # Use full data for real future prediction
99
- train_df = df
100
 
101
- # Format for TimesFM
102
  input_df = pd.DataFrame({
103
  'unique_id': [ticker],
104
  'ds': train_df.index,
105
  'y': train_df['Close'].values
106
  })
107
 
108
- # Run Forecast
109
  forecast_df, _ = tfm.forecast_on_df(
110
  inputs=input_df,
111
  freq="D",
112
- value_name="y",
113
- forecast_context_len=512 # Explicitly use context
114
  )
115
- forecast_df = forecast_df.head(horizon)
116
 
117
- # Generate Signal (Only for Future Mode)
118
  if mode != "Backtest (Reality Check)":
119
- start_price = train_df['Close'].iloc[-1]
120
- end_price = forecast_df['timesfm'].iloc[-1]
121
- pct_change = ((end_price - start_price) / start_price) * 100
122
-
123
- color = "#00ff88" if pct_change > 0 else "#ff4444"
124
- direction = "BULLISH" if pct_change > 0 else "BEARISH"
125
- signal_html = f"<h3 style='color: {color}; margin: 0;'>{direction} ({pct_change:+.2f}%)</h3>"
126
  else:
127
- # Calculate Accuracy Score for Backtest
128
- real_end = df['Close'].iloc[-1]
129
- pred_end = forecast_df['timesfm'].iloc[-1]
130
- error = abs(real_end - pred_end) / real_end * 100
131
- accuracy = 100 - error
132
- signal_html = f"<h3 style='color: #FFD700; margin: 0;'>AI Accuracy: {accuracy:.1f}%</h3>"
133
 
134
- fig = get_financial_plot(df, forecast_df, ticker, is_backtest=(mode == "Backtest (Reality Check)"))
135
-
136
- return fig, signal_html
137
 
138
  except Exception as e:
139
- return None, f"<span style='color:red'>Error: {str(e)}</span>"
140
 
141
- # UI Layout
142
  with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #000000}") as demo:
143
- gr.HTML("""
144
- <div style='background-color: #131722; padding: 15px; border-bottom: 2px solid #2962FF;'>
145
- <h2 style='color: white; margin:0; text-align: center; font-family: sans-serif;'>
146
- G-TIMES <span style='color: #2962FF;'>PRO TERMINAL</span>
147
- </h2>
148
- </div>
149
- """)
150
-
151
  with gr.Row():
152
- with gr.Column(scale=1, min_width=300):
153
- with gr.Group():
154
- ticker_in = gr.Textbox(label="SYMBOL", value="BTC-USD")
155
- mode_in = gr.Radio(["Future Forecast", "Backtest (Reality Check)"],
156
- label="ANALYSIS MODE", value="Future Forecast")
157
- days_in = gr.Slider(5, 128, value=14, label="Horizon (Days)")
158
-
159
- btn = gr.Button("EXECUTE STRATEGY", variant="primary")
160
-
161
- gr.HTML("<br>")
162
- result_box = gr.HTML(label="Signal")
163
-
164
  with gr.Column(scale=4):
165
- plot_out = gr.Plot(label="Technical Chart")
166
 
167
  btn.click(run_analysis, [ticker_in, days_in, mode_in], [plot_out, result_box])
168
 
 
1
  import gradio as gr
2
  import pandas as pd
 
 
 
3
  import numpy as np
4
  import yfinance as yf
5
+ import plotly.graph_objects as go
6
+ from plotly.subplots import make_subplots
7
  from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint
8
 
9
+ # --- JAX INITIALIZATION ---
10
+ # Using the native JAX/Flax checkpoint
11
  tfm = TimesFm(
12
  hparams=TimesFmHparams(
13
+ backend="cpu", # JAX handles CPU/GPU automatically
14
  per_core_batch_size=32,
15
  horizon_len=128,
16
  context_len=512,
 
18
  model_dims=1280,
19
  ),
20
  checkpoint=TimesFmCheckpoint(
21
+ huggingface_repo_id="google/timesfm-1.0-200m" # Original JAX weights
22
  ),
23
  )
24
 
25
  def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
 
26
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
27
  vertical_spacing=0.05, row_heights=[0.75, 0.25])
28
 
29
+ # Plotting logic remains consistent for professional look
 
30
  display_df = df[:-30] if is_backtest else df
31
 
32
+ # 1. Price
33
+ fig.add_trace(go.Scatter(x=display_df.index, y=display_df['Close'], name='History', line=dict(color='#2962FF')), row=1, col=1)
 
 
 
34
 
35
+ # 2. Backtest Reality (Truth)
36
  if is_backtest:
37
+ fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual', line=dict(color='#787b86', dash='dot')), row=1, col=1)
 
 
 
 
 
38
 
39
+ # 3. Forecast
40
+ fc_dates = [display_df.index[-1]] + list(forecast_df['ds'])
41
+ fc_vals = [display_df['Close'].iloc[-1]] + list(forecast_df['timesfm'])
42
+ fig.add_trace(go.Scatter(x=fc_dates, y=fc_vals, name='AI Forecast', line=dict(color='#F23645', width=3)), row=1, col=1)
 
 
 
 
 
 
 
 
 
 
43
 
44
+ # 4. Volume
45
+ fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume', marker_color='rgba(38, 166, 154, 0.5)'), row=2, col=1)
 
 
 
 
46
 
47
+ fig.update_layout(template='plotly_dark', paper_bgcolor='#131722', plot_bgcolor='#131722', margin=dict(l=10, r=10, t=40, b=10))
 
 
 
 
 
 
 
 
 
 
 
48
  return fig
49
 
50
  def run_analysis(ticker, horizon, mode):
51
  try:
52
+ df = yf.download(ticker, period="2y")
 
53
  if df.empty: return None, "⚠️ Ticker Not Found"
54
 
55
+ train_df = df[:-30] if mode == "Backtest (Reality Check)" else df
 
 
 
 
 
 
 
56
 
57
+ # Prepare inputs for TimesFM
58
  input_df = pd.DataFrame({
59
  'unique_id': [ticker],
60
  'ds': train_df.index,
61
  'y': train_df['Close'].values
62
  })
63
 
64
+ # Forecast Execution (Native JAX backend)
65
  forecast_df, _ = tfm.forecast_on_df(
66
  inputs=input_df,
67
  freq="D",
68
+ value_name="y"
 
69
  )
70
+ forecast_df = forecast_df.head(horizon if mode != "Backtest (Reality Check)" else 30)
71
 
72
+ # Signal Generation
73
  if mode != "Backtest (Reality Check)":
74
+ pct = ((forecast_df['timesfm'].iloc[-1] - train_df['Close'].iloc[-1]) / train_df['Close'].iloc[-1]) * 100
75
+ signal = f"<h3 style='color: {'#00ff88' if pct > 0 else '#ff4444'};'>{ 'BULLISH' if pct > 0 else 'BEARISH' } ({pct:+.2f}%)</h3>"
 
 
 
 
 
76
  else:
77
+ acc = 100 - (abs(df['Close'].iloc[-1] - forecast_df['timesfm'].iloc[-1]) / df['Close'].iloc[-1] * 100)
78
+ signal = f"<h3 style='color: #FFD700;'>Model Confidence: {acc:.1f}%</h3>"
 
 
 
 
79
 
80
+ return get_financial_plot(df, forecast_df, ticker, (mode == "Backtest (Reality Check)")), signal
 
 
81
 
82
  except Exception as e:
83
+ return None, f"Runtime Error: {str(e)}"
84
 
85
+ # --- UI Layout ---
86
  with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #000000}") as demo:
87
+ gr.HTML("<h2 style='color: #2962FF; text-align: center;'>G-TIMES JAX TERMINAL</h2>")
 
 
 
 
 
 
 
88
  with gr.Row():
89
+ with gr.Column(scale=1):
90
+ ticker_in = gr.Textbox(label="SYMBOL", value="NVDA")
91
+ mode_in = gr.Radio(["Future Forecast", "Backtest (Reality Check)"], label="MODE", value="Future Forecast")
92
+ days_in = gr.Slider(5, 128, value=30, label="Days")
93
+ btn = gr.Button("RUN JAX INFERENCE", variant="primary")
94
+ result_box = gr.HTML()
 
 
 
 
 
 
95
  with gr.Column(scale=4):
96
+ plot_out = gr.Plot()
97
 
98
  btn.click(run_analysis, [ticker_in, days_in, mode_in], [plot_out, result_box])
99