Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,42 +4,38 @@ import numpy as np
|
|
| 4 |
import yfinance as yf
|
| 5 |
import plotly.graph_objects as go
|
| 6 |
from plotly.subplots import make_subplots
|
| 7 |
-
|
| 8 |
|
| 9 |
-
# ---
|
| 10 |
-
#
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
backend="cpu", # JAX handles CPU/GPU automatically
|
| 14 |
-
per_core_batch_size=32,
|
| 15 |
-
horizon_len=128,
|
| 16 |
-
context_len=512,
|
| 17 |
-
num_layers=20,
|
| 18 |
-
model_dims=1280,
|
| 19 |
-
),
|
| 20 |
-
checkpoint=TimesFmCheckpoint(
|
| 21 |
-
huggingface_repo_id="google/timesfm-1.0-200m" # Original JAX weights
|
| 22 |
-
),
|
| 23 |
)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
|
| 26 |
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
|
| 27 |
vertical_spacing=0.05, row_heights=[0.75, 0.25])
|
| 28 |
|
| 29 |
-
# Plotting logic remains consistent for professional look
|
| 30 |
display_df = df[:-30] if is_backtest else df
|
| 31 |
|
| 32 |
-
# 1. Price
|
| 33 |
fig.add_trace(go.Scatter(x=display_df.index, y=display_df['Close'], name='History', line=dict(color='#2962FF')), row=1, col=1)
|
| 34 |
|
| 35 |
-
# 2. Backtest
|
| 36 |
if is_backtest:
|
| 37 |
fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual', line=dict(color='#787b86', dash='dot')), row=1, col=1)
|
| 38 |
|
| 39 |
-
# 3. Forecast
|
| 40 |
-
fc_dates =
|
| 41 |
-
|
| 42 |
-
fig.add_trace(go.Scatter(x=fc_dates, y=fc_vals, name='AI Forecast', line=dict(color='#F23645', width=3)), row=1, col=1)
|
| 43 |
|
| 44 |
# 4. Volume
|
| 45 |
fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume', marker_color='rgba(38, 166, 154, 0.5)'), row=2, col=1)
|
|
@@ -50,51 +46,51 @@ def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
|
|
| 50 |
def run_analysis(ticker, horizon, mode):
|
| 51 |
try:
|
| 52 |
df = yf.download(ticker, period="2y")
|
| 53 |
-
if df.empty: return None, "⚠️ Ticker Not Found"
|
| 54 |
|
| 55 |
train_df = df[:-30] if mode == "Backtest (Reality Check)" else df
|
| 56 |
|
| 57 |
-
#
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# Forecast Execution (Native JAX backend)
|
| 65 |
-
forecast_df, _ = tfm.forecast_on_df(
|
| 66 |
-
inputs=input_df,
|
| 67 |
-
freq="D",
|
| 68 |
-
value_name="y"
|
| 69 |
)
|
| 70 |
-
forecast_df = forecast_df.head(horizon if mode != "Backtest (Reality Check)" else 30)
|
| 71 |
|
| 72 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
if mode != "Backtest (Reality Check)":
|
| 74 |
-
pct = ((
|
| 75 |
signal = f"<h3 style='color: {'#00ff88' if pct > 0 else '#ff4444'};'>{ 'BULLISH' if pct > 0 else 'BEARISH' } ({pct:+.2f}%)</h3>"
|
| 76 |
else:
|
| 77 |
-
acc = 100 - (abs(df['Close'].iloc[-1] -
|
| 78 |
-
signal = f"<h3 style='color: #FFD700;'>
|
| 79 |
|
| 80 |
-
return get_financial_plot(df,
|
| 81 |
|
| 82 |
except Exception as e:
|
| 83 |
-
return None, f"Runtime Error: {str(e)}"
|
| 84 |
|
| 85 |
-
# --- UI
|
| 86 |
with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #000000}") as demo:
|
| 87 |
-
gr.HTML("<h2 style='color: #2962FF; text-align: center;'>G-TIMES
|
| 88 |
with gr.Row():
|
| 89 |
with gr.Column(scale=1):
|
| 90 |
ticker_in = gr.Textbox(label="SYMBOL", value="NVDA")
|
| 91 |
mode_in = gr.Radio(["Future Forecast", "Backtest (Reality Check)"], label="MODE", value="Future Forecast")
|
| 92 |
days_in = gr.Slider(5, 128, value=30, label="Days")
|
| 93 |
-
btn = gr.Button("RUN
|
| 94 |
result_box = gr.HTML()
|
|
|
|
| 95 |
with gr.Column(scale=4):
|
| 96 |
plot_out = gr.Plot()
|
| 97 |
-
|
| 98 |
-
btn.click(run_analysis, [ticker_in, days_in, mode_in], [plot_out, result_box])
|
| 99 |
|
| 100 |
demo.launch()
|
|
|
|
| 4 |
import yfinance as yf
|
| 5 |
import plotly.graph_objects as go
|
| 6 |
from plotly.subplots import make_subplots
|
| 7 |
+
import timesfm
|
| 8 |
|
| 9 |
+
# --- UPDATED INITIALIZATION (v2.5 API) ---
|
| 10 |
+
# We use the new torch-specific class and ForecastConfig object
|
| 11 |
+
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
|
| 12 |
+
"google/timesfm-2.5-200m-pytorch"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
)
|
| 14 |
|
| 15 |
+
# Configure the model using the new ForecastConfig
|
| 16 |
+
model.compile(timesfm.ForecastConfig(
|
| 17 |
+
max_context=512,
|
| 18 |
+
max_horizon=128,
|
| 19 |
+
normalize_inputs=True,
|
| 20 |
+
infer_is_positive=True
|
| 21 |
+
))
|
| 22 |
+
|
| 23 |
def get_financial_plot(df, forecast_df, ticker, is_backtest=False):
|
| 24 |
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
|
| 25 |
vertical_spacing=0.05, row_heights=[0.75, 0.25])
|
| 26 |
|
|
|
|
| 27 |
display_df = df[:-30] if is_backtest else df
|
| 28 |
|
| 29 |
+
# 1. Historical Price
|
| 30 |
fig.add_trace(go.Scatter(x=display_df.index, y=display_df['Close'], name='History', line=dict(color='#2962FF')), row=1, col=1)
|
| 31 |
|
| 32 |
+
# 2. Backtest Truth
|
| 33 |
if is_backtest:
|
| 34 |
fig.add_trace(go.Scatter(x=df.index[-30:], y=df['Close'][-30:], name='Actual', line=dict(color='#787b86', dash='dot')), row=1, col=1)
|
| 35 |
|
| 36 |
+
# 3. AI Forecast
|
| 37 |
+
fc_dates = pd.date_range(start=display_df.index[-1], periods=len(forecast_df)+1, freq='D')[1:]
|
| 38 |
+
fig.add_trace(go.Scatter(x=fc_dates, y=forecast_df, name='AI Forecast', line=dict(color='#F23645', width=3)), row=1, col=1)
|
|
|
|
| 39 |
|
| 40 |
# 4. Volume
|
| 41 |
fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='Volume', marker_color='rgba(38, 166, 154, 0.5)'), row=2, col=1)
|
|
|
|
| 46 |
def run_analysis(ticker, horizon, mode):
|
| 47 |
try:
|
| 48 |
df = yf.download(ticker, period="2y")
|
| 49 |
+
if df.empty: return None, "⚠️ Ticker Not Found", None
|
| 50 |
|
| 51 |
train_df = df[:-30] if mode == "Backtest (Reality Check)" else df
|
| 52 |
|
| 53 |
+
# New API uses simple array inputs
|
| 54 |
+
inputs = [train_df['Close'].values]
|
| 55 |
+
|
| 56 |
+
# Run inference using the v2.5 forecast method
|
| 57 |
+
point_forecast, _ = model.forecast(
|
| 58 |
+
inputs=inputs,
|
| 59 |
+
horizon=horizon if mode != "Backtest (Reality Check)" else 30
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
)
|
|
|
|
| 61 |
|
| 62 |
+
forecast_values = point_forecast[0] # Get results for the first (only) input
|
| 63 |
+
|
| 64 |
+
# Save CSV
|
| 65 |
+
csv_path = f"{ticker}_forecast.csv"
|
| 66 |
+
pd.DataFrame({'Forecast': forecast_values}).to_csv(csv_path, index=False)
|
| 67 |
+
|
| 68 |
+
# Signals
|
| 69 |
if mode != "Backtest (Reality Check)":
|
| 70 |
+
pct = ((forecast_values[-1] - train_df['Close'].iloc[-1]) / train_df['Close'].iloc[-1]) * 100
|
| 71 |
signal = f"<h3 style='color: {'#00ff88' if pct > 0 else '#ff4444'};'>{ 'BULLISH' if pct > 0 else 'BEARISH' } ({pct:+.2f}%)</h3>"
|
| 72 |
else:
|
| 73 |
+
acc = 100 - (abs(df['Close'].iloc[-1] - forecast_values[-1]) / df['Close'].iloc[-1] * 100)
|
| 74 |
+
signal = f"<h3 style='color: #FFD700;'>Accuracy: {acc:.1f}%</h3>"
|
| 75 |
|
| 76 |
+
return get_financial_plot(df, forecast_values, ticker, (mode == "Backtest (Reality Check)")), signal, csv_path
|
| 77 |
|
| 78 |
except Exception as e:
|
| 79 |
+
return None, f"Runtime Error: {str(e)}", None
|
| 80 |
|
| 81 |
+
# --- UI ---
|
| 82 |
with gr.Blocks(theme=gr.themes.Default(), css=".gradio-container {background-color: #000000}") as demo:
|
| 83 |
+
gr.HTML("<h2 style='color: #2962FF; text-align: center;'>G-TIMES 2.5 QUANT TERMINAL</h2>")
|
| 84 |
with gr.Row():
|
| 85 |
with gr.Column(scale=1):
|
| 86 |
ticker_in = gr.Textbox(label="SYMBOL", value="NVDA")
|
| 87 |
mode_in = gr.Radio(["Future Forecast", "Backtest (Reality Check)"], label="MODE", value="Future Forecast")
|
| 88 |
days_in = gr.Slider(5, 128, value=30, label="Days")
|
| 89 |
+
btn = gr.Button("RUN ANALYSIS", variant="primary")
|
| 90 |
result_box = gr.HTML()
|
| 91 |
+
file_output = gr.File(label="Download Forecast")
|
| 92 |
with gr.Column(scale=4):
|
| 93 |
plot_out = gr.Plot()
|
| 94 |
+
btn.click(run_analysis, [ticker_in, days_in, mode_in], [plot_out, result_box, file_output])
|
|
|
|
| 95 |
|
| 96 |
demo.launch()
|