test-predictor / app.py
dkasjdlwhf's picture
Update app.py
bd54cb1 verified
import gradio as gr
import torch
import yfinance as yf
import numpy as np
from chronos import ChronosPipeline
def _load_chronos_pipeline():
    """Load the Chronos T5-base forecasting pipeline.

    Invoked exactly once at import time so that every prediction request
    reuses the same loaded weights. ``device_map`` and ``torch_dtype`` are
    forwarded unchanged to ``ChronosPipeline.from_pretrained``.
    """
    print("Loading Chronos model...")
    return ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-base",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )


# Module-level pipeline shared by all requests.
pipeline = _load_chronos_pipeline()
def predict_stock_numeric(ticker_symbol, timeframe):
try:
tf_map = {
"5 mins": ("1d", "5m"), "15 mins": ("1d", "15m"),
"1 hr": ("1mo", "1h"), "1 day": ("1y", "1d"),
"1 week": ("2y", "1wk"), "1 month": ("5y", "1mo"),
"1 year": ("max", "1y")
}
period, interval = tf_map.get(timeframe, ("2y", "1d"))
df = yf.download(ticker_symbol, period=period, interval=interval)
if df.empty: return "Error: Ticker not found."
prices_series = df['Close'].values.flatten()
context_tensor = torch.tensor(prices_series)
forecast_result = pipeline.predict(context_tensor, 1)
median_pred = np.median(forecast_result[0].numpy(), axis=0)[0]
current_price = prices_series[-1]
change = median_pred - current_price
percent_change = (change / current_price) * 100
return (
f"Ticker: {ticker_symbol}\n"
f"Timeframe: {timeframe}\n"
f"Last Close: {current_price:.2f}\n"
f"Predicted Next Close: {median_pred:.2f}\n"
f"Expected Move: {change:+.2f} ({percent_change:+.2f}%)"
)
except Exception as e:
return f"Error: {str(e)}"
# Tickers offered in the dropdown; users may still type any other symbol.
_TICKER_CHOICES = [
    "AAPL", "TSLA", "GOOGL", "MSFT", "AMZN", "NVDA", "BTC-USD", "ETH-USD",
]
# Timeframe labels offered to the user; they are passed verbatim to
# predict_stock_numeric.
_TIMEFRAME_CHOICES = ["5 mins", "15 mins", "1 hr", "1 day", "1 week", "1 month"]

with gr.Blocks() as demo:
    gr.Markdown("# Strix Technologies Predictor")

    # Ticker and timeframe selectors side by side.
    with gr.Row():
        ticker_input = gr.Dropdown(
            label="Select or Type Ticker",
            choices=_TICKER_CHOICES,
            value="AAPL",
            allow_custom_value=True,
        )
        timeframe_input = gr.Dropdown(
            label="Timeframe",
            choices=_TIMEFRAME_CHOICES,
            value="1 day",
        )

    btn = gr.Button("Predict Next Close")
    output_text = gr.Textbox(label="Prediction Result", lines=6)

    # Clicking the button runs the forecast and writes its text summary.
    btn.click(
        fn=predict_stock_numeric,
        inputs=[ticker_input, timeframe_input],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()