pyroleli commited on
Commit
5e546fa
·
verified ·
1 Parent(s): 77f52f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -55
app.py CHANGED
@@ -2,100 +2,120 @@ import gradio as gr
2
  import pandas as pd
3
  import torch
4
  import plotly.graph_objects as go
 
5
  import numpy as np
6
  import yfinance as yf
7
  from timesfm import TimesFm
8
 
9
- # Initialize Google TimesFM
10
- # We use the 200M model which is optimized for Hugging Face Free Tier
11
  tfm = TimesFm(
12
- context_len=512,
13
- horizon_len=128,
14
  input_patch_len=32,
15
  output_patch_len=128,
16
  num_layers=20,
17
  model_dims=1280,
18
- backend="cpu", # Change to "gpu" if your Space has a GPU
19
  )
20
 
21
  tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
22
 
23
  def get_tradingview_plot(df, forecast_df, ticker):
24
- fig = go.Figure()
 
 
 
25
 
26
- # Historical Price (Professional Blue)
27
  fig.add_trace(go.Scatter(
28
  x=df.index, y=df['Close'],
29
- name='Historical',
30
  line=dict(color='#2962FF', width=2)
31
- ))
32
 
33
- # TimesFM Forecast (Dotted Red)
34
  fig.add_trace(go.Scatter(
35
  x=forecast_df['ds'], y=forecast_df['timesfm'],
36
- name='TimesFM Forecast',
37
  line=dict(color='#F23645', width=3, dash='dot')
38
- ))
39
 
40
- # Add TradingView Styling
 
 
 
 
 
 
 
 
41
  fig.update_layout(
42
  template='plotly_dark',
43
  hovermode="x unified",
44
  paper_bgcolor='#131722',
45
  plot_bgcolor='#131722',
 
46
  margin=dict(l=10, r=10, t=50, b=10),
47
- xaxis=dict(showgrid=True, gridcolor='#2a2e39'),
48
- yaxis=dict(showgrid=True, gridcolor='#2a2e39', side='right'),
49
- legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01)
50
  )
 
 
 
 
51
  return fig
52
 
53
  def run_analysis(ticker, horizon):
54
- # Fetch Data
55
- df = yf.download(ticker, period="1y")
56
- if df.empty: return None, "Invalid Ticker"
57
-
58
- # TimesFM expects a specific format: unique_id, ds (date), y (value)
59
- # We create a simple dataframe for the model
60
- input_df = pd.DataFrame({
61
- 'unique_id': [ticker],
62
- 'ds': df.index,
63
- 'y': df['Close'].values
64
- })
65
-
66
- # Execute Forecast
67
- # freq: 0 for daily, 1 for weekly, 2 for monthly
68
- forecast_df, _ = tfm.forecast_on_df(
69
- inputs=input_df,
70
- freq="D",
71
- value_name="y"
72
- )
73
 
74
- # Filter forecast to the user-selected horizon
75
- forecast_df = forecast_df.head(horizon)
76
-
77
- # Calculate professional metrics
78
- current_price = df['Close'].iloc[-1]
79
- final_pred = forecast_df['timesfm'].iloc[-1]
80
- change_pct = ((final_pred - current_price) / current_price) * 100
81
-
82
- status = "BULLISH" if change_pct > 0 else "BEARISH"
83
- signal_text = f"{status} ({change_pct:+.2f}%)"
84
 
85
- fig = get_tradingview_plot(df, forecast_df, ticker)
86
-
87
- return fig, signal_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Professional UI Template
90
- with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #131722;}") as demo:
91
- gr.Markdown("<h1 style='text-align: center; color: white;'>G-TIMES QUANT TERMINAL</h1>")
92
 
93
  with gr.Row():
94
  with gr.Column(scale=1):
95
- ticker_input = gr.Textbox(label="Ticker Symbol", value="NVDA")
96
- horizon_slider = gr.Slider(5, 128, value=30, label="Forecast Horizon")
97
- analyze_btn = gr.Button("RUN MODELS", variant="primary")
98
- output_signal = gr.Label(label="Market Signal")
 
 
 
 
99
 
100
  with gr.Column(scale=4):
101
  plot_output = gr.Plot()
 
2
  import pandas as pd
3
  import torch
4
  import plotly.graph_objects as go
5
+ from plotly.subplots import make_subplots
6
  import numpy as np
7
  import yfinance as yf
8
  from timesfm import TimesFm
9
 
10
+ # Initialize Google TimesFM with updated API arguments
11
+ # Note: 'context_len' changed to 'context_length', etc.
12
  tfm = TimesFm(
13
+ context_length=512,
14
+ horizon_length=128,
15
  input_patch_len=32,
16
  output_patch_len=128,
17
  num_layers=20,
18
  model_dims=1280,
19
+ backend="cpu",
20
  )
21
 
22
  tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
23
 
24
  def get_tradingview_plot(df, forecast_df, ticker):
25
+ # Create subplots: Chart on top (80%), Volume on bottom (20%)
26
+ fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
27
+ vertical_spacing=0.03, subplot_titles=(f'{ticker} Analysis', ''),
28
+ row_width=[0.2, 0.8])
29
 
30
+ # 1. Main Price Line
31
  fig.add_trace(go.Scatter(
32
  x=df.index, y=df['Close'],
33
+ name='Historical Close',
34
  line=dict(color='#2962FF', width=2)
35
+ ), row=1, col=1)
36
 
37
+ # 2. TimesFM Forecast (Dotted Projection)
38
  fig.add_trace(go.Scatter(
39
  x=forecast_df['ds'], y=forecast_df['timesfm'],
40
+ name='AI Forecast',
41
  line=dict(color='#F23645', width=3, dash='dot')
42
+ ), row=1, col=1)
43
 
44
+ # 3. Volume Profile (Bars)
45
+ fig.add_trace(go.Bar(
46
+ x=df.index, y=df['Volume'],
47
+ name='Volume',
48
+ marker_color='#26a69a',
49
+ opacity=0.5
50
+ ), row=2, col=1)
51
+
52
+ # Professional Dark Theme Styling
53
  fig.update_layout(
54
  template='plotly_dark',
55
  hovermode="x unified",
56
  paper_bgcolor='#131722',
57
  plot_bgcolor='#131722',
58
+ showlegend=True,
59
  margin=dict(l=10, r=10, t=50, b=10),
60
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
 
 
61
  )
62
+
63
+ fig.update_xaxes(showgrid=True, gridcolor='#2a2e39', rangeslider_visible=False)
64
+ fig.update_yaxes(showgrid=True, gridcolor='#2a2e39', side='right')
65
+
66
  return fig
67
 
68
  def run_analysis(ticker, horizon):
69
+ try:
70
+ df = yf.download(ticker, period="1y")
71
+ if df.empty: return None, "⚠️ Invalid Ticker"
72
+
73
+ # Format for TimesFM
74
+ input_df = pd.DataFrame({
75
+ 'unique_id': [ticker],
76
+ 'ds': df.index,
77
+ 'y': df['Close'].values
78
+ })
 
 
 
 
 
 
 
 
 
79
 
80
+ # Execute Forecast
81
+ forecast_df, _ = tfm.forecast_on_df(
82
+ inputs=input_df,
83
+ freq="D",
84
+ value_name="y"
85
+ )
 
 
 
 
86
 
87
+ forecast_df = forecast_df.head(horizon)
88
+
89
+ # Calculate Signals
90
+ current_price = df['Close'].iloc[-1]
91
+ final_pred = forecast_df['timesfm'].iloc[-1]
92
+ change_pct = ((final_pred - current_price) / current_price) * 100
93
+
94
+ signal = "STRONG BUY" if change_pct > 5 else "BUY" if change_pct > 0 else "SELL"
95
+ signal_color = "#00ff88" if "BUY" in signal else "#ff4444"
96
+
97
+ fig = get_tradingview_plot(df, forecast_df, ticker)
98
+
99
+ html_signal = f"<h2 style='color: {signal_color}; text-align: center;'>{signal} ({change_pct:+.2f}%)</h2>"
100
+
101
+ return fig, html_signal
102
+ except Exception as e:
103
+ return None, f"Error: {str(e)}"
104
 
105
+ # Professional Layout
106
+ with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #131722; color: white;}") as demo:
107
+ gr.HTML("<h1 style='text-align: center; color: #2962FF; font-family: sans-serif;'>G-TIMES QUANT TERMINAL</h1>")
108
 
109
  with gr.Row():
110
  with gr.Column(scale=1):
111
+ with gr.Box():
112
+ ticker_input = gr.Textbox(label="Ticker Symbol", value="NVDA")
113
+ horizon_slider = gr.Slider(5, 128, value=30, label="Forecast Horizon (Days)")
114
+ analyze_btn = gr.Button("RUN QUANT ANALYSIS", variant="primary")
115
+
116
+ output_signal = gr.HTML(label="Market Signal")
117
+ gr.Markdown("---")
118
+ gr.Markdown("### Terminal Info\n- **Model**: Google TimesFM-1.0\n- **Backend**: CPU PyTorch\n- **Freq**: Daily Close")
119
 
120
  with gr.Column(scale=4):
121
  plot_output = gr.Plot()