dkasjdlwhf commited on
Commit
9e66a31
·
verified ·
1 Parent(s): 8c1fd3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -26
app.py CHANGED
@@ -2,35 +2,33 @@ import gradio as gr
2
  import torch
3
  import yfinance as yf
4
  import numpy as np
 
 
 
 
 
 
 
 
 
5
 
6
  def predict_stock_numeric(ticker_symbol, timeframe):
7
  try:
8
- # Map timeframe to yfinance period/interval
9
  tf_map = {
10
- "5 mins": ("1d", "5m"),
11
- "15 mins": ("1d", "15m"),
12
- "1 hr": ("1mo", "1h"),
13
- "1 day": ("1y", "1d"),
14
- "1 week": ("2y", "1wk"),
15
- "1 month": ("5y", "1mo"),
16
  "1 year": ("max", "1y")
17
  }
18
-
19
  period, interval = tf_map.get(timeframe, ("2y", "1d"))
20
 
21
- # 1. Fetch Data
22
  df = yf.download(ticker_symbol, period=period, interval=interval)
23
- if df.empty:
24
- return "Error: Ticker not found or no data for this timeframe."
25
 
26
  prices_series = df['Close'].values.flatten()
27
-
28
- # 2. Inference
29
  context_tensor = torch.tensor(prices_series)
30
- pred_len = 1
31
- forecast_result = pipeline.predict(context_tensor, pred_len)
32
 
33
- # 3. Process Result
34
  median_pred = np.median(forecast_result[0].numpy(), axis=0)[0]
35
  current_price = prices_series[-1]
36
  change = median_pred - current_price
@@ -46,26 +44,20 @@ def predict_stock_numeric(ticker_symbol, timeframe):
46
  except Exception as e:
47
  return f"Error: {str(e)}"
48
 
49
- # Create GUI
50
  with gr.Blocks() as demo:
51
  gr.Markdown("# Strix Technologies Predictor")
52
  with gr.Row():
53
- # Dropdown with allow_custom_value lets users pick or type
54
  ticker_input = gr.Dropdown(
55
  choices=["AAPL", "TSLA", "GOOGL", "MSFT", "AMZN", "NVDA", "BTC-USD", "ETH-USD"],
56
- label="Select or Type Ticker",
57
- value="AAPL",
58
- allow_custom_value=True
59
  )
60
  timeframe_input = gr.Dropdown(
61
  choices=["5 mins", "15 mins", "1 hr", "1 day", "1 week", "1 month", "1 year"],
62
- label="Timeframe",
63
- value="1 day"
64
  )
65
  btn = gr.Button("Predict Next Close")
66
  output_text = gr.Textbox(label="Prediction Result", lines=6)
67
-
68
  btn.click(fn=predict_stock_numeric, inputs=[ticker_input, timeframe_input], outputs=output_text)
69
 
70
- # share=True creates a temporary public link (72h)
71
- demo.launch(debug=True, share=True)
 
2
  import torch
3
  import yfinance as yf
4
  import numpy as np
5
+ from chronos import ChronosPipeline
6
+
7
+ # Initialize model once for the app
8
+ print("Loading Chronos model...")
9
+ pipeline = ChronosPipeline.from_pretrained(
10
+ "amazon/chronos-t5-base",
11
+ device_map="auto",
12
+ torch_dtype=torch.bfloat16,
13
+ )
14
 
15
  def predict_stock_numeric(ticker_symbol, timeframe):
16
  try:
 
17
  tf_map = {
18
+ "5 mins": ("1d", "5m"), "15 mins": ("1d", "15m"),
19
+ "1 hr": ("1mo", "1h"), "1 day": ("1y", "1d"),
20
+ "1 week": ("2y", "1wk"), "1 month": ("5y", "1mo"),
 
 
 
21
  "1 year": ("max", "1y")
22
  }
 
23
  period, interval = tf_map.get(timeframe, ("2y", "1d"))
24
 
 
25
  df = yf.download(ticker_symbol, period=period, interval=interval)
26
+ if df.empty: return "Error: Ticker not found."
 
27
 
28
  prices_series = df['Close'].values.flatten()
 
 
29
  context_tensor = torch.tensor(prices_series)
30
+ forecast_result = pipeline.predict(context_tensor, 1)
 
31
 
 
32
  median_pred = np.median(forecast_result[0].numpy(), axis=0)[0]
33
  current_price = prices_series[-1]
34
  change = median_pred - current_price
 
44
  except Exception as e:
45
  return f"Error: {str(e)}"
46
 
 
47
  with gr.Blocks() as demo:
48
  gr.Markdown("# Strix Technologies Predictor")
49
  with gr.Row():
 
50
  ticker_input = gr.Dropdown(
51
  choices=["AAPL", "TSLA", "GOOGL", "MSFT", "AMZN", "NVDA", "BTC-USD", "ETH-USD"],
52
+ label="Select or Type Ticker", value="AAPL", allow_custom_value=True
 
 
53
  )
54
  timeframe_input = gr.Dropdown(
55
  choices=["5 mins", "15 mins", "1 hr", "1 day", "1 week", "1 month", "1 year"],
56
+ label="Timeframe", value="1 day"
 
57
  )
58
  btn = gr.Button("Predict Next Close")
59
  output_text = gr.Textbox(label="Prediction Result", lines=6)
 
60
  btn.click(fn=predict_stock_numeric, inputs=[ticker_input, timeframe_input], outputs=output_text)
61
 
62
+ if __name__ == '__main__':
63
+ demo.launch()