pyroleli commited on
Commit
20267f0
·
verified ·
1 Parent(s): dabaaa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -47
app.py CHANGED
@@ -4,42 +4,38 @@ import numpy as np
4
  import yfinance as yf
5
  import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
- from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint
8
 
9
- # --- JAX INITIALIZATION ---
10
- # Using the native JAX/Flax checkpoint
11
- tfm = TimesFm(
12
- hparams=TimesFmHparams(
13
- backend="cpu", # JAX handles CPU/GPU automatically
14
- per_core_batch_size=32,
15
- horizon_len=128,
16
- context_len=512,
17
- num_layers=20,
18
- model_dims=1280,
19
- ),
20
- checkpoint=TimesFmCheckpoint(
21
- huggingface_repo_id="google/timesfm-1.0-200m" # Original JAX weights
22
- ),
23
  )
24
 
 
 
 
 
 
 
 
 
25
  def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
26
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
27
  vertical_spacing=0.05, row_heights=[0.75, 0.25])
28
 
29
- # Plotting logic remains consistent for professional look
30
  display_df = df[:-30] if is_backtest else df
31
 
32
- # 1. Price
33
  fig.add_trace(go.Scatter(x=display_df.index, y=display_df['Close'], name='History', line=dict(color='#2962FF')), row=1, col=1)
34
 
35
- # 2. Backtest Reality (Truth)
36
  if is_backtest:
37
  fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual', line=dict(color='#787b86', dash='dot')), row=1, col=1)
38
 
39
- # 3. Forecast
40
- fc_dates = [display_df.index[-1]] + list(forecast_df['ds'])
41
- fc_vals = [display_df['Close'].iloc[-1]] + list(forecast_df['timesfm'])
42
- fig.add_trace(go.Scatter(x=fc_dates, y=fc_vals, name='AI Forecast', line=dict(color='#F23645', width=3)), row=1, col=1)
43
 
44
  # 4. Volume
45
  fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume', marker_color='rgba(38, 166, 154, 0.5)'), row=2, col=1)
@@ -50,51 +46,51 @@ def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
50
  def run_analysis(ticker, horizon, mode):
51
  try:
52
  df = yf.download(ticker, period="2y")
53
- if df.empty: return None, "⚠️ Ticker Not Found"
54
 
55
  train_df = df[:-30] if mode == "Backtest (Reality Check)" else df
56
 
57
- # Prepare inputs for TimesFM
58
- input_df = pd.DataFrame({
59
- 'unique_id': [ticker],
60
- 'ds': train_df.index,
61
- 'y': train_df['Close'].values
62
- })
63
-
64
- # Forecast Execution (Native JAX backend)
65
- forecast_df, _ = tfm.forecast_on_df(
66
- inputs=input_df,
67
- freq="D",
68
- value_name="y"
69
  )
70
- forecast_df = forecast_df.head(horizon if mode != "Backtest (Reality Check)" else 30)
71
 
72
- # Signal Generation
 
 
 
 
 
 
73
  if mode != "Backtest (Reality Check)":
74
- pct = ((forecast_df['timesfm'].iloc[-1] - train_df['Close'].iloc[-1]) / train_df['Close'].iloc[-1]) * 100
75
  signal = f"<h3 style='color: {'#00ff88' if pct > 0 else '#ff4444'};'>{ 'BULLISH' if pct > 0 else 'BEARISH' } ({pct:+.2f}%)</h3>"
76
  else:
77
- acc = 100 - (abs(df['Close'].iloc[-1] - forecast_df['timesfm'].iloc[-1]) / df['Close'].iloc[-1] * 100)
78
- signal = f"<h3 style='color: #FFD700;'>Model Confidence: {acc:.1f}%</h3>"
79
 
80
- return get_financial_plot(df, forecast_df, ticker, (mode == "Backtest (Reality Check)")), signal
81
 
82
  except Exception as e:
83
- return None, f"Runtime Error: {str(e)}"
84
 
85
- # --- UI Layout ---
86
  with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #000000}") as demo:
87
- gr.HTML("<h2 style='color: #2962FF; text-align: center;'>G-TIMES JAX TERMINAL</h2>")
88
  with gr.Row():
89
  with gr.Column(scale=1):
90
  ticker_in = gr.Textbox(label="SYMBOL", value="NVDA")
91
  mode_in = gr.Radio(["Future Forecast", "Backtest (Reality Check)"], label="MODE", value="Future Forecast")
92
  days_in = gr.Slider(5, 128, value=30, label="Days")
93
- btn = gr.Button("RUN JAX INFERENCE", variant="primary")
94
  result_box = gr.HTML()
 
95
  with gr.Column(scale=4):
96
  plot_out = gr.Plot()
97
-
98
- btn.click(run_analysis, [ticker_in, days_in, mode_in], [plot_out, result_box])
99
 
100
  demo.launch()
 
4
  import yfinance as yf
5
  import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
+ import timesfm
8
 
9
+ # --- UPDATED INITIALIZATION (v2.5 API) ---
10
+ # We use the new torch-specific class and ForecastConfig object
11
+ model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
12
+ "google/timesfm-2.5-200m-pytorch"
 
 
 
 
 
 
 
 
 
 
13
  )
14
 
15
+ # Configure the model using the new ForecastConfig
16
+ model.compile(timesfm.ForecastConfig(
17
+ max_context=512,
18
+ max_horizon=128,
19
+ normalize_inputs=True,
20
+ infer_is_positive=True
21
+ ))
22
+
23
  def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
24
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
25
  vertical_spacing=0.05, row_heights=[0.75, 0.25])
26
 
 
27
  display_df = df[:-30] if is_backtest else df
28
 
29
+ # 1. Historical Price
30
  fig.add_trace(go.Scatter(x=display_df.index, y=display_df['Close'], name='History', line=dict(color='#2962FF')), row=1, col=1)
31
 
32
+ # 2. Backtest Truth
33
  if is_backtest:
34
  fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual', line=dict(color='#787b86', dash='dot')), row=1, col=1)
35
 
36
+ # 3. AI Forecast
37
+ fc_dates = pd.date_range(start=display_df.index[-1], periods=len(forecast_df)+1, freq='D')[1:]
38
+ fig.add_trace(go.Scatter(x=fc_dates, y=forecast_df, name='AI Forecast', line=dict(color='#F23645', width=3)), row=1, col=1)
 
39
 
40
  # 4. Volume
41
  fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume', marker_color='rgba(38, 166, 154, 0.5)'), row=2, col=1)
 
46
  def run_analysis(ticker, horizon, mode):
47
  try:
48
  df = yf.download(ticker, period="2y")
49
+ if df.empty: return None, "⚠️ Ticker Not Found", None
50
 
51
  train_df = df[:-30] if mode == "Backtest (Reality Check)" else df
52
 
53
+ # New API uses simple array inputs
54
+ inputs = [train_df['Close'].values]
55
+
56
+ # Run inference using the v2.5 forecast method
57
+ point_forecast, _ = model.forecast(
58
+ inputs=inputs,
59
+ horizon=horizon if mode != "Backtest (Reality Check)" else 30
 
 
 
 
 
60
  )
 
61
 
62
+ forecast_values = point_forecast[0] # Get results for the first (only) input
63
+
64
+ # Save CSV
65
+ csv_path = f"{ticker}_forecast.csv"
66
+ pd.DataFrame({'Forecast': forecast_values}).to_csv(csv_path, index=False)
67
+
68
+ # Signals
69
  if mode != "Backtest (Reality Check)":
70
+ pct = ((forecast_values[-1] - train_df['Close'].iloc[-1]) / train_df['Close'].iloc[-1]) * 100
71
  signal = f"<h3 style='color: {'#00ff88' if pct > 0 else '#ff4444'};'>{ 'BULLISH' if pct > 0 else 'BEARISH' } ({pct:+.2f}%)</h3>"
72
  else:
73
+ acc = 100 - (abs(df['Close'].iloc[-1] - forecast_values[-1]) / df['Close'].iloc[-1] * 100)
74
+ signal = f"<h3 style='color: #FFD700;'>Accuracy: {acc:.1f}%</h3>"
75
 
76
+ return get_financial_plot(df, forecast_values, ticker, (mode == "Backtest (Reality Check)")), signal, csv_path
77
 
78
  except Exception as e:
79
+ return None, f"Runtime Error: {str(e)}", None
80
 
81
+ # --- UI ---
82
  with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #000000}") as demo:
83
+ gr.HTML("<h2 style='color: #2962FF; text-align: center;'>G-TIMES 2.5 QUANT TERMINAL</h2>")
84
  with gr.Row():
85
  with gr.Column(scale=1):
86
  ticker_in = gr.Textbox(label="SYMBOL", value="NVDA")
87
  mode_in = gr.Radio(["Future Forecast", "Backtest (Reality Check)"], label="MODE", value="Future Forecast")
88
  days_in = gr.Slider(5, 128, value=30, label="Days")
89
+ btn = gr.Button("RUN ANALYSIS", variant="primary")
90
  result_box = gr.HTML()
91
+ file_output = gr.File(label="Download Forecast")
92
  with gr.Column(scale=4):
93
  plot_out = gr.Plot()
94
+ btn.click(run_analysis, [ticker_in, days_in, mode_in], [plot_out, result_box, file_output])
 
95
 
96
  demo.launch()