Spaces:
Sleeping
Sleeping
update app
Browse files
app.py
CHANGED
|
@@ -3,7 +3,7 @@ import yfinance as yf
|
|
| 3 |
import pandas as pd
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
-
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 7 |
from datetime import datetime, timedelta
|
| 8 |
import plotly.graph_objects as go
|
| 9 |
import plotly.express as px
|
|
@@ -29,14 +29,14 @@ from config import IDX_STOCKS, TECHNICAL_INDICATORS, PREDICTION_CONFIG
|
|
| 29 |
# Load Chronos-Bolt model
|
| 30 |
@spaces.GPU(duration=120)
|
| 31 |
def load_model():
|
| 32 |
-
"""Load the Amazon Chronos-Bolt model for time series forecasting
|
| 33 |
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 34 |
"amazon/chronos-bolt-base",
|
| 35 |
torch_dtype=torch.bfloat16,
|
| 36 |
device_map="auto",
|
| 37 |
trust_remote_code=True
|
| 38 |
)
|
| 39 |
-
# CRITICAL FIX: Skip tokenizer loading
|
| 40 |
tokenizer = None
|
| 41 |
return model, tokenizer
|
| 42 |
|
|
@@ -73,15 +73,14 @@ def analyze_stock(symbol, prediction_days=30):
|
|
| 73 |
signals = generate_trading_signals(data, indicators)
|
| 74 |
|
| 75 |
# Make predictions using Chronos-Bolt
|
| 76 |
-
# Passing the (now None) tokenizer argument to maintain compatibility with utils.py signature
|
| 77 |
predictions = predict_prices(data, model, tokenizer, prediction_days)
|
| 78 |
|
| 79 |
# Create charts
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
|
| 84 |
-
return fundamental_info, indicators, signals,
|
| 85 |
|
| 86 |
def create_ui():
|
| 87 |
"""Create the Gradio interface"""
|
|
@@ -212,7 +211,7 @@ def create_ui():
|
|
| 212 |
|
| 213 |
# Event handlers
|
| 214 |
def update_analysis(symbol, pred_days):
|
| 215 |
-
fundamental_info, indicators, signals,
|
| 216 |
|
| 217 |
if fundamental_info is None:
|
| 218 |
return {
|
|
@@ -242,6 +241,8 @@ def create_ui():
|
|
| 242 |
}
|
| 243 |
|
| 244 |
# Format outputs
|
|
|
|
|
|
|
| 245 |
return {
|
| 246 |
company_name: fundamental_info.get('name', 'N/A'),
|
| 247 |
current_price: fundamental_info.get('current_price', 0),
|
|
@@ -263,9 +264,9 @@ def create_ui():
|
|
| 263 |
predicted_low: indicators.get('predictions', {}).get('low_30d', 0),
|
| 264 |
predicted_change: indicators.get('predictions', {}).get('change_pct', 0),
|
| 265 |
prediction_summary: indicators.get('predictions', {}).get('summary', ''),
|
| 266 |
-
price_chart:
|
| 267 |
-
technical_chart:
|
| 268 |
-
prediction_chart:
|
| 269 |
}
|
| 270 |
|
| 271 |
analyze_btn.click(
|
|
|
|
| 3 |
import pandas as pd
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 7 |
from datetime import datetime, timedelta
|
| 8 |
import plotly.graph_objects as go
|
| 9 |
import plotly.express as px
|
|
|
|
| 29 |
# Load Chronos-Bolt model
|
| 30 |
@spaces.GPU(duration=120)
|
| 31 |
def load_model():
|
| 32 |
+
"""Load the Amazon Chronos-Bolt model for time series forecasting"""
|
| 33 |
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 34 |
"amazon/chronos-bolt-base",
|
| 35 |
torch_dtype=torch.bfloat16,
|
| 36 |
device_map="auto",
|
| 37 |
trust_remote_code=True
|
| 38 |
)
|
| 39 |
+
# CRITICAL FIX: Skip tokenizer loading
|
| 40 |
tokenizer = None
|
| 41 |
return model, tokenizer
|
| 42 |
|
|
|
|
| 73 |
signals = generate_trading_signals(data, indicators)
|
| 74 |
|
| 75 |
# Make predictions using Chronos-Bolt
|
|
|
|
| 76 |
predictions = predict_prices(data, model, tokenizer, prediction_days)
|
| 77 |
|
| 78 |
# Create charts
|
| 79 |
+
fig_price = create_price_chart(data, indicators) # RENAMED
|
| 80 |
+
fig_technical = create_technical_chart(data, indicators) # RENAMED
|
| 81 |
+
fig_prediction = create_prediction_chart(data, predictions) # RENAMED
|
| 82 |
|
| 83 |
+
return fundamental_info, indicators, signals, fig_price, fig_technical, fig_prediction # Returning renamed local variables
|
| 84 |
|
| 85 |
def create_ui():
|
| 86 |
"""Create the Gradio interface"""
|
|
|
|
| 211 |
|
| 212 |
# Event handlers
|
| 213 |
def update_analysis(symbol, pred_days):
|
| 214 |
+
fundamental_info, indicators, signals, fig_price, fig_technical, fig_prediction = analyze_stock(symbol, pred_days) # Receives renamed local variables
|
| 215 |
|
| 216 |
if fundamental_info is None:
|
| 217 |
return {
|
|
|
|
| 241 |
}
|
| 242 |
|
| 243 |
# Format outputs
|
| 244 |
+
# CRITICAL FIX: Use the Gradio component objects (left) as keys
|
| 245 |
+
# and the Plotly Figure objects (right) as values.
|
| 246 |
return {
|
| 247 |
company_name: fundamental_info.get('name', 'N/A'),
|
| 248 |
current_price: fundamental_info.get('current_price', 0),
|
|
|
|
| 264 |
predicted_low: indicators.get('predictions', {}).get('low_30d', 0),
|
| 265 |
predicted_change: indicators.get('predictions', {}).get('change_pct', 0),
|
| 266 |
prediction_summary: indicators.get('predictions', {}).get('summary', ''),
|
| 267 |
+
price_chart: fig_price,
|
| 268 |
+
technical_chart: fig_technical,
|
| 269 |
+
prediction_chart: fig_prediction
|
| 270 |
}
|
| 271 |
|
| 272 |
analyze_btn.click(
|