pyroleli commited on
Commit
a47157b
·
verified ·
1 Parent(s): e25c096

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import yfinance as yf
7
+ from chronos import Chronos2Pipeline
8
+
9
+ # 1. Load the Chronos-2 Model
10
+ # Using 'small' for faster performance on Free Tier CPUs
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ pipeline = Chronos2Pipeline.from_pretrained(
13
+ "amazon/chronos-2",
14
+ device_map=device,
15
+ torch_dtype=torch.float32,
16
+ )
17
+
18
+ def forecast_stock(ticker, forecast_days):
19
+ # 2. Fetch Historical Data
20
+ data = yf.download(ticker, period="1y")
21
+ if data.empty:
22
+ return None, "Error: Ticker not found."
23
+
24
+ # Extract closing prices
25
+ df = data[['Close']].reset_index()
26
+ context = torch.tensor(df['Close'].values)
27
+
28
+ # 3. Perform Inference
29
+ # Chronos produces probabilistic forecasts (multiple samples)
30
+ forecast = pipeline.predict(context, forecast_days) # shape: [1, num_samples, forecast_days]
31
+
32
+ # Calculate quantiles for the "dotted" area
33
+ # 0.5 is the median (main line), 0.1 and 0.9 create the confidence interval
34
+ low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
35
+
36
+ # 4. Plotting
37
+ plt.figure(figsize=(12, 6))
38
+
39
+ # Previous Data (Solid Line)
40
+ history_indices = np.arange(len(context))
41
+ plt.plot(history_indices, context, color='royalblue', label="Historical Price", linewidth=2)
42
+
43
+ # Forecast Data (Dotted Section)
44
+ forecast_indices = np.arange(len(context), len(context) + forecast_days)
45
+
46
+ # Median Forecast (Dashed)
47
+ plt.plot(forecast_indices, median, color='tomato', linestyle='--', label="Median Forecast")
48
+
49
+ # Shaded Uncertainty Area
50
+ plt.fill_between(forecast_indices, low, high, color='tomato', alpha=0.2, label="80% Confidence Interval")
51
+
52
+ plt.title(f"Chronos-2 Forecast for {ticker}", fontsize=14)
53
+ plt.xlabel("Days (Relative)")
54
+ plt.ylabel("Price (USD)")
55
+ plt.legend()
56
+ plt.grid(True, alpha=0.3)
57
+
58
+ return plt.gcf()
59
+
60
+ # 5. Build Gradio Interface
61
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
62
+ gr.Markdown("# 📈 Amazon Chronos-2 Financial Forecaster")
63
+ gr.Markdown("Input a stock ticker (e.g., AAPL, BTC-USD) to see the AI predict future trends.")
64
+
65
+ with gr.Row():
66
+ ticker_input = gr.Textbox(label="Stock Ticker", value="AAPL")
67
+ days_input = gr.Slider(minimum=5, maximum=60, step=1, label="Forecast Horizon (Days)", value=30)
68
+
69
+ forecast_btn = gr.Button("Generate Forecast", variant="primary")
70
+ plot_output = gr.Plot(label="Time Series Forecast")
71
+
72
+ forecast_btn.click(
73
+ fn=forecast_stock,
74
+ inputs=[ticker_input, days_input],
75
+ outputs=plot_output
76
+ )
77
+
78
+ demo.launch()