pyroleli commited on
Commit
dce0ea3
·
verified ·
1 Parent(s): c33df41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -82
app.py CHANGED
@@ -1,109 +1,90 @@
1
  import gradio as gr
2
  import pandas as pd
3
- import numpy as np
4
  import yfinance as yf
5
  import plotly.graph_objects as go
6
- from plotly.subplots import make_subplots
7
- import timesfm
8
 
9
- # --- SDK v2.5 INITIALIZATION ---
10
- # Load the 200M parameter model specifically for PyTorch
11
- model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
12
- "google/timesfm-2.5-200m-pytorch"
 
 
13
  )
14
 
15
- # FIXED: Parameters are now moved into ForecastConfig
16
- model.compile(timesfm.ForecastConfig(
17
- max_context=1024, # v2.5 now supports up to 16k
18
- max_horizon=256,
19
- normalize_inputs=True,
20
- infer_is_positive=True
21
- ))
22
-
23
- def generate_professional_chart(df, forecast_data, ticker, is_backtest=False):
24
- fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
25
- vertical_spacing=0.03, row_heights=[0.8, 0.2])
26
-
27
  hist_df = df[:-30] if is_backtest else df
28
 
29
- # Historical Trend
30
- fig.add_trace(go.Scatter(x=hist_df.index, y=hist_df['Close'], name='Market Price',
31
- line=dict(color='#2962FF', width=2)), row=1, col=1)
32
 
33
- # Reality Check (Dotted)
34
- if is_backtest:
35
- fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual Path',
36
- line=dict(color='#787b86', dash='dot')), row=1, col=1)
37
-
38
- # AI Forecast Line
39
- fc_dates = pd.date_range(start=hist_df.index[-1], periods=len(forecast_data)+1, freq='B')[1:]
40
- fig.add_trace(go.Scatter(x=fc_dates, y=forecast_data, name='AI Forecast',
41
- line=dict(color='#F23645', width=3)), row=1, col=1)
 
 
 
 
42
 
43
- # Volume Bars
44
- fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume',
45
- marker_color='rgba(120, 123, 134, 0.3)'), row=2, col=1)
46
 
47
- fig.update_layout(template='plotly_dark', paper_bgcolor='#131722', plot_bgcolor='#131722',
48
- margin=dict(l=50, r=50, t=30, b=30), hovermode="x unified")
49
- fig.update_yaxes(side='right', gridcolor='#2a2e39')
50
- fig.update_xaxes(gridcolor='#2a2e39')
51
  return fig
52
 
53
- def run_terminal(ticker, horizon, mode):
54
  try:
55
- # 1. Data Ingestion
56
- data = yf.download(ticker, period="2y")
57
- if data.empty: return None, "⚠️ SYMBOL NOT FOUND", None
58
 
59
- train_data = data[:-30] if mode == "Backtest Mode" else data
 
60
 
61
- # 2. SDK Inference (Array-based input for v2.5)
62
- inputs = [train_data['Close'].values]
63
- point_forecast, _ = model.forecast(
64
- inputs=inputs,
65
- horizon=30 if mode == "Backtest Mode" else horizon
66
- )
67
- prediction = point_forecast[0]
68
-
69
- # 3. Financial Metrics
70
- pct_change = ((prediction[-1] - train_data['Close'].iloc[-1]) / train_data['Close'].iloc[-1]) * 100
71
 
72
- if mode == "Future Forecast":
73
- signal = f"<h2 style='color: {'#00ff88' if pct_change > 0 else '#ff4444'}; text-align: center;'>" \
74
- f"{'BULLISH' if pct_change > 0 else 'BEARISH'} ({pct_change:+.2f}%)</h2>"
75
- else:
76
- # Calculate Backtest Accuracy
77
- accuracy = 100 - abs((data['Close'].iloc[-1] - prediction[-1]) / data['Close'].iloc[-1] * 100)
78
- signal = f"<h2 style='color: #FFD700; text-align: center;'>AI ACCURACY: {accuracy:.1f}%</h2>"
79
-
80
- # 4. Generate Report
81
- report_path = f"{ticker}_AI_Report.csv"
82
- pd.DataFrame({'Forecast_Price': prediction}).to_csv(report_path, index=False)
83
 
84
- chart = generate_professional_chart(data, prediction, ticker, (mode == "Backtest Mode"))
85
- return chart, signal, report_path
86
-
 
87
  except Exception as e:
88
- return None, f"<div style='color:red;'>API Error: {str(e)}</div>", None
89
 
90
- # --- TERMINAL UI ---
91
- with gr.Blocks(title="G-TIMES QUANT 2.5", theme=gr.themes.Base()) as demo:
92
- gr.HTML("<div style='background-color:#131722; padding:20px; border-bottom:3px solid #2962FF; text-align:center;'>"
93
- "<h1 style='color:white; letter-spacing:3px;'>G-TIMES <span style='color:#2962FF;'>QUANT 2.5</span></h1></div>")
94
-
95
  with gr.Row():
96
  with gr.Column(scale=1):
97
- ticker_in = gr.Textbox(label="TICKER", value="TSLA")
98
- mode_in = gr.Radio(["Future Forecast", "Backtest Mode"], label="STRATEGY", value="Future Forecast")
99
- days_in = gr.Slider(7, 128, value=30, label="HORIZON")
100
- btn = gr.Button("RUN QUANT ANALYSIS", variant="primary")
101
- status_out = gr.HTML()
102
- file_out = gr.File(label="CSV EXPORT")
103
-
104
  with gr.Column(scale=4):
105
- plot_out = gr.Plot()
106
 
107
- btn.click(run_terminal, [ticker_in, days_in, mode_in], [plot_out, status_out, file_out])
108
 
109
  demo.launch()
 
1
  import gradio as gr
2
  import pandas as pd
3
+ import torch
4
  import yfinance as yf
5
  import plotly.graph_objects as go
6
+ from chronos import ChronosPipeline
 
7
 
8
+ # --- LOAD AMAZON CHRONOS-2 ---
9
+ # This model treats time-series as a language, making it more flexible
10
+ pipeline = ChronosPipeline.from_pretrained(
11
+ "amazon/chronos-t5-base",
12
+ device_map="cpu", # Use "cuda" if you have a GPU
13
+ torch_dtype=torch.float32,
14
  )
15
 
16
+ def get_pro_chart(df, forecast_samples, ticker, is_backtest=False):
17
+ # Chronos provides a distribution (multiple possible futures)
18
+ # We take the median (50th percentile) for the main line
19
+ low, median, high = np.quantile(forecast_samples, [0.1, 0.5, 0.9], axis=0)
20
+
21
+ fig = go.Figure()
 
 
 
 
 
 
22
  hist_df = df[:-30] if is_backtest else df
23
 
24
+ # 1. Historical Data
25
+ fig.add_trace(go.Scatter(x=hist_df.index, y=hist_df['Close'], name='History', line=dict(color='#2962FF')))
 
26
 
27
+ # 2. Future Forecast (Median)
28
+ last_date = hist_df.index[-1]
29
+ fc_dates = pd.date_range(start=last_date, periods=len(median)+1, freq='B')[1:]
30
+
31
+ # Confidence Interval (The "Cloud")
32
+ fig.add_trace(go.Scatter(
33
+ x=list(fc_dates) + list(fc_dates)[::-1],
34
+ y=list(high) + list(low)[::-1],
35
+ fill='toself', fillcolor='rgba(242, 54, 69, 0.1)',
36
+ line=dict(color='rgba(255,255,255,0)'), name='Confidence Range'
37
+ ))
38
+
39
+ fig.add_trace(go.Scatter(x=fc_dates, y=median, name='AI Median Forecast', line=dict(color='#F23645', width=3)))
40
 
41
+ if is_backtest:
42
+ fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual', line=dict(color='#787b86', dash='dot')))
 
43
 
44
+ fig.update_layout(template='plotly_dark', paper_bgcolor='#131722', plot_bgcolor='#131722', margin=dict(l=10, r=10, t=40, b=10))
 
 
 
45
  return fig
46
 
47
+ def analyze(ticker, horizon, mode):
48
  try:
49
+ df = yf.download(ticker, period="2y")
50
+ if df.empty: return None, "Symbol not found", None
 
51
 
52
+ context_data = df[:-30] if mode == "Backtest" else df
53
+ context_tensor = torch.tensor(context_data['Close'].values)
54
 
55
+ # Chronos Inference
56
+ # h = horizon, num_samples = 20 (to get a confidence range)
57
+ forecast = pipeline.predict(context_tensor, horizon if mode != "Backtest" else 30, num_samples=20)
58
+ forecast_np = forecast.numpy()[0] # Shape: [samples, horizon]
59
+ median_fc = np.median(forecast_np, axis=0)
 
 
 
 
 
60
 
61
+ # Results & Signals
62
+ change = ((median_fc[-1] - context_data['Close'].iloc[-1]) / context_data['Close'].iloc[-1]) * 100
63
+ signal = f"<h2 style='color: {'#00ff88' if change > 0 else '#ff4444'};'>{'BULLISH' if change > 0 else 'BEARISH'} ({change:+.2f}%)</h2>"
 
 
 
 
 
 
 
 
64
 
65
+ csv_path = "forecast.csv"
66
+ pd.DataFrame(forecast_np.T).to_csv(csv_path)
67
+
68
+ return get_pro_chart(df, forecast_np, ticker, mode == "Backtest"), signal, csv_path
69
  except Exception as e:
70
+ return None, f"Error: {str(e)}", None
71
 
72
+ import numpy as np # Needed for quantiles
73
+
74
+ # --- GRADIO UI ---
75
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
76
+ gr.HTML("<h1 style='text-align:center; color:#2962FF;'>CHRONOS-2 ANALYTICS</h1>")
77
  with gr.Row():
78
  with gr.Column(scale=1):
79
+ t_in = gr.Textbox(label="Ticker", value="AAPL")
80
+ m_in = gr.Radio(["Future Forecast", "Backtest"], label="Mode", value="Future Forecast")
81
+ h_in = gr.Slider(7, 90, value=30, label="Days")
82
+ btn = gr.Button("RUN AI MODEL", variant="primary")
83
+ msg = gr.HTML()
84
+ file = gr.File(label="Export Data")
 
85
  with gr.Column(scale=4):
86
+ plot = gr.Plot()
87
 
88
+ btn.click(analyze, [t_in, h_in, m_in], [plot, msg, file])
89
 
90
  demo.launch()