omniverse1 commited on
Commit
561d1d4
·
verified ·
1 Parent(s): 47b8dda

update app

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -3,7 +3,7 @@ import yfinance as yf
3
  import pandas as pd
4
  import numpy as np
5
  import torch
6
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer # Keep import for dependency check
7
  from datetime import datetime, timedelta
8
  import plotly.graph_objects as go
9
  import plotly.express as px
@@ -29,14 +29,14 @@ from config import IDX_STOCKS, TECHNICAL_INDICATORS, PREDICTION_CONFIG
29
  # Load Chronos-Bolt model
30
  @spaces.GPU(duration=120)
31
  def load_model():
32
- """Load the Amazon Chronos-Bolt model for time series forecasting. Tokenizer loading is skipped to bypass fatal error."""
33
  model = AutoModelForSeq2SeqLM.from_pretrained(
34
  "amazon/chronos-bolt-base",
35
  torch_dtype=torch.bfloat16,
36
  device_map="auto",
37
  trust_remote_code=True
38
  )
39
- # CRITICAL FIX: Skip tokenizer loading (it is not used in predict_prices anyway)
40
  tokenizer = None
41
  return model, tokenizer
42
 
@@ -73,15 +73,14 @@ def analyze_stock(symbol, prediction_days=30):
73
  signals = generate_trading_signals(data, indicators)
74
 
75
  # Make predictions using Chronos-Bolt
76
- # Passing the (now None) tokenizer argument to maintain compatibility with utils.py signature
77
  predictions = predict_prices(data, model, tokenizer, prediction_days)
78
 
79
  # Create charts
80
- price_chart = create_price_chart(data, indicators)
81
- technical_chart = create_technical_chart(data, indicators)
82
- prediction_chart = create_prediction_chart(data, predictions)
83
 
84
- return fundamental_info, indicators, signals, price_chart, technical_chart, prediction_chart
85
 
86
  def create_ui():
87
  """Create the Gradio interface"""
@@ -212,7 +211,7 @@ def create_ui():
212
 
213
  # Event handlers
214
  def update_analysis(symbol, pred_days):
215
- fundamental_info, indicators, signals, price_chart, technical_chart, prediction_chart = analyze_stock(symbol, pred_days)
216
 
217
  if fundamental_info is None:
218
  return {
@@ -242,6 +241,8 @@ def create_ui():
242
  }
243
 
244
  # Format outputs
 
 
245
  return {
246
  company_name: fundamental_info.get('name', 'N/A'),
247
  current_price: fundamental_info.get('current_price', 0),
@@ -263,9 +264,9 @@ def create_ui():
263
  predicted_low: indicators.get('predictions', {}).get('low_30d', 0),
264
  predicted_change: indicators.get('predictions', {}).get('change_pct', 0),
265
  prediction_summary: indicators.get('predictions', {}).get('summary', ''),
266
- price_chart: price_chart,
267
- technical_chart: technical_chart,
268
- prediction_chart: prediction_chart
269
  }
270
 
271
  analyze_btn.click(
 
3
  import pandas as pd
4
  import numpy as np
5
  import torch
6
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
  from datetime import datetime, timedelta
8
  import plotly.graph_objects as go
9
  import plotly.express as px
 
29
  # Load Chronos-Bolt model
30
  @spaces.GPU(duration=120)
31
  def load_model():
32
+ """Load the Amazon Chronos-Bolt model for time series forecasting"""
33
  model = AutoModelForSeq2SeqLM.from_pretrained(
34
  "amazon/chronos-bolt-base",
35
  torch_dtype=torch.bfloat16,
36
  device_map="auto",
37
  trust_remote_code=True
38
  )
39
+ # CRITICAL FIX: Skip tokenizer loading
40
  tokenizer = None
41
  return model, tokenizer
42
 
 
73
  signals = generate_trading_signals(data, indicators)
74
 
75
  # Make predictions using Chronos-Bolt
 
76
  predictions = predict_prices(data, model, tokenizer, prediction_days)
77
 
78
  # Create charts
79
+ fig_price = create_price_chart(data, indicators) # RENAMED
80
+ fig_technical = create_technical_chart(data, indicators) # RENAMED
81
+ fig_prediction = create_prediction_chart(data, predictions) # RENAMED
82
 
83
+ return fundamental_info, indicators, signals, fig_price, fig_technical, fig_prediction # Returning renamed local variables
84
 
85
  def create_ui():
86
  """Create the Gradio interface"""
 
211
 
212
  # Event handlers
213
  def update_analysis(symbol, pred_days):
214
+ fundamental_info, indicators, signals, fig_price, fig_technical, fig_prediction = analyze_stock(symbol, pred_days) # Receives renamed local variables
215
 
216
  if fundamental_info is None:
217
  return {
 
241
  }
242
 
243
  # Format outputs
244
+ # CRITICAL FIX: Use the Gradio component objects (left) as keys
245
+ # and the Plotly Figure objects (right) as values.
246
  return {
247
  company_name: fundamental_info.get('name', 'N/A'),
248
  current_price: fundamental_info.get('current_price', 0),
 
264
  predicted_low: indicators.get('predictions', {}).get('low_30d', 0),
265
  predicted_change: indicators.get('predictions', {}).get('change_pct', 0),
266
  prediction_summary: indicators.get('predictions', {}).get('summary', ''),
267
+ price_chart: fig_price,
268
+ technical_chart: fig_technical,
269
+ prediction_chart: fig_prediction
270
  }
271
 
272
  analyze_btn.click(