Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import pandas as pd
|
| 3 |
-
import
|
| 4 |
from prophet import Prophet
|
| 5 |
import plotly.graph_objs as go
|
| 6 |
import math
|
| 7 |
-
import
|
| 8 |
-
from
|
| 9 |
-
from
|
|
|
|
| 10 |
|
| 11 |
# --- Replace with your Alpha Vantage API key ---
|
| 12 |
ALPHA_VANTAGE_API_KEY = "YOUR_ALPHA_VANTAGE_API_KEY" # <--- Replace with your key
|
|
@@ -16,6 +18,63 @@ CRYPTO_SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
|
| 16 |
STOCK_SYMBOLS = ["AAPL", "MSFT"]
|
| 17 |
INTERVAL_OPTIONS = ["1h", "60min"] # Consistent naming
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
# --- Technical Analysis Functions ---
|
| 20 |
def calculate_technical_indicators(df):
|
| 21 |
"""Calculates RSI, MACD, and Bollinger Bands."""
|
|
@@ -164,31 +223,60 @@ def create_forecast_plot(forecast_df):
|
|
| 164 |
)
|
| 165 |
return fig
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
# --- Main Prediction and Display Function ---
|
| 168 |
-
def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale):
|
| 169 |
"""Main function to orchestrate data fetching, analysis, and prediction."""
|
| 170 |
df = pd.DataFrame()
|
| 171 |
error_message = ""
|
| 172 |
-
|
| 173 |
# 1. Data Fetching
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
df = fetch_crypto_data(symbol)
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
except Exception as e:
|
| 183 |
-
error_message = f"Error fetching stock data: {e}"
|
| 184 |
-
else:
|
| 185 |
-
error_message = "Invalid market type selected."
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
| 189 |
|
|
|
|
|
|
|
|
|
|
| 190 |
# 2. Preprocessing & Technical Analysis
|
| 191 |
-
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
| 192 |
numeric_cols = ["open", "high", "low", "close", "volume"]
|
| 193 |
df[numeric_cols] = df[numeric_cols].astype(float)
|
| 194 |
df = calculate_technical_indicators(df)
|
|
@@ -209,7 +297,7 @@ def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonal
|
|
| 209 |
|
| 210 |
if prophet_error:
|
| 211 |
error_message = f"Prophet Error: {prophet_error}"
|
| 212 |
-
return None, None, None, None, None, "", error_message #Return error
|
| 213 |
|
| 214 |
forecast_plot = create_forecast_plot(forecast_df)
|
| 215 |
|
|
@@ -233,8 +321,8 @@ def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonal
|
|
| 233 |
# Prepare forecast data for the Dataframe output
|
| 234 |
forecast_df_display = forecast_df.loc[:, ["ds", "yhat", "yhat_lower", "yhat_upper"]].copy()
|
| 235 |
forecast_df_display.rename(columns={"ds": "Date", "yhat": "Forecast", "yhat_lower": "Lower Bound", "yhat_upper": "Upper Bound"}, inplace=True)
|
|
|
|
| 236 |
|
| 237 |
-
return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df_display, growth_label, error_message #Return error
|
| 238 |
# --- Gradio Interface ---
|
| 239 |
with gr.Blocks(theme=gr.themes.Base()) as demo:
|
| 240 |
gr.Markdown("# Market Analysis and Prediction")
|
|
@@ -250,6 +338,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
|
|
| 250 |
yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
|
| 251 |
seasonality_mode_dd = gr.Dropdown(label="Seasonality Mode", choices=["additive", "multiplicative"], value="additive")
|
| 252 |
changepoint_scale_slider = gr.Slider(label="Changepoint Prior Scale", minimum=0.01, maximum=1.0, step=0.01, value=0.05)
|
|
|
|
| 253 |
|
| 254 |
with gr.Column():
|
| 255 |
forecast_plot = gr.Plot(label="Price Forecast")
|
|
@@ -260,6 +349,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
|
|
| 260 |
macd_plot = gr.Plot(label="MACD")
|
| 261 |
forecast_df = gr.Dataframe(label="Forecast Data", headers=["Date", "Forecast", "Lower Bound", "Upper Bound"])
|
| 262 |
growth_label_output = gr.Label(label="Explosive Growth Prediction") # Added for prediction.
|
|
|
|
| 263 |
|
| 264 |
# Event Listener to update symbol dropdown based on market type
|
| 265 |
def update_symbol_choices(market_type):
|
|
@@ -283,8 +373,9 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
|
|
| 283 |
yearly_box,
|
| 284 |
seasonality_mode_dd,
|
| 285 |
changepoint_scale_slider,
|
|
|
|
| 286 |
],
|
| 287 |
-
outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df, growth_label_output]
|
| 288 |
)
|
| 289 |
|
| 290 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
import gradio as gr
|
| 3 |
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
from prophet import Prophet
|
| 6 |
import plotly.graph_objs as go
|
| 7 |
import math
|
| 8 |
+
import requests # For API calls
|
| 9 |
+
from sklearn.ensemble import RandomForestClassifier # For the model
|
| 10 |
+
from textblob import TextBlob # For sentiment analysis (optional)
|
| 11 |
+
|
| 12 |
|
| 13 |
# --- Replace with your Alpha Vantage API key ---
|
| 14 |
ALPHA_VANTAGE_API_KEY = "YOUR_ALPHA_VANTAGE_API_KEY" # <--- Replace with your key
|
|
|
|
| 18 |
STOCK_SYMBOLS = ["AAPL", "MSFT"]
|
| 19 |
INTERVAL_OPTIONS = ["1h", "60min"] # Consistent naming
|
| 20 |
|
| 21 |
+
# --- Data Fetching Functions ---
|
| 22 |
+
def fetch_crypto_data(symbol, interval="1h", limit=100):
|
| 23 |
+
"""Fetch crypto market data from Binance."""
|
| 24 |
+
try:
|
| 25 |
+
url = f"https://api.binance.com/api/v3/klines"
|
| 26 |
+
params = {"symbol": symbol, "interval": interval, "limit": limit}
|
| 27 |
+
response = requests.get(url, params=params)
|
| 28 |
+
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
|
| 29 |
+
data = response.json()
|
| 30 |
+
df = pd.DataFrame(data, columns=["timestamp", "open", "high", "low", "close", "volume"])
|
| 31 |
+
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
|
| 32 |
+
df[["open", "high", "low", "close", "volume"]] = df[["open", "high", "low", "close", "volume"]].astype(float)
|
| 33 |
+
return df.dropna()
|
| 34 |
+
except requests.exceptions.RequestException as e:
|
| 35 |
+
raise Exception(f"Error fetching crypto data: {e}")
|
| 36 |
+
except (ValueError, KeyError) as e:
|
| 37 |
+
raise Exception(f"Error parsing crypto data: {e}")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def fetch_stock_data(symbol, interval="60min"):
|
| 41 |
+
"""Fetch stock market data from Alpha Vantage."""
|
| 42 |
+
try:
|
| 43 |
+
url = f"https://www.alphavantage.co/query"
|
| 44 |
+
params = {"function": "TIME_SERIES_INTRADAY", "symbol": symbol, "interval": interval,
|
| 45 |
+
"apikey": ALPHA_VANTAGE_API_KEY}
|
| 46 |
+
response = requests.get(url, params=params)
|
| 47 |
+
response.raise_for_status() # Raise HTTPError for bad responses
|
| 48 |
+
data = response.json()
|
| 49 |
+
if "Time Series (60min)" in data: # Check the JSON for the correct key
|
| 50 |
+
time_series_data = data["Time Series (60min)"]
|
| 51 |
+
df = pd.DataFrame(time_series_data).T.reset_index()
|
| 52 |
+
df.columns = ["timestamp", "open", "high", "low", "close", "volume"] # Standardize
|
| 53 |
+
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
| 54 |
+
df[["open", "high", "low", "close", "volume"]] = df[["open", "high", "low", "close", "volume"]].astype(float)
|
| 55 |
+
return df.dropna()
|
| 56 |
+
else:
|
| 57 |
+
raise Exception(f"Error: Could not retrieve stock data. Check API Key, Symbol, and Interval. Response: {data}")
|
| 58 |
+
|
| 59 |
+
except requests.exceptions.RequestException as e:
|
| 60 |
+
raise Exception(f"Error fetching stock data: {e}")
|
| 61 |
+
except (ValueError, KeyError) as e:
|
| 62 |
+
raise Exception(f"Error parsing stock data: {e}")
|
| 63 |
+
|
| 64 |
+
def fetch_sentiment_data(keyword): # Placeholder - replace with a real sentiment analysis method
|
| 65 |
+
"""Analyze sentiment from social media (placeholder)."""
|
| 66 |
+
try:
|
| 67 |
+
tweets = [
|
| 68 |
+
f"{keyword} is going to moon!",
|
| 69 |
+
f"I hate {keyword}, it's trash!",
|
| 70 |
+
f"{keyword} is amazing!"
|
| 71 |
+
]
|
| 72 |
+
sentiments = [TextBlob(tweet).sentiment.polarity for tweet in tweets]
|
| 73 |
+
return sum(sentiments) / len(sentiments) / len(sentiments) if sentiments else 0 # Avoid ZeroDivisionError
|
| 74 |
+
except Exception as e:
|
| 75 |
+
print(f"Sentiment analysis error: {e}")
|
| 76 |
+
return 0
|
| 77 |
+
|
| 78 |
# --- Technical Analysis Functions ---
|
| 79 |
def calculate_technical_indicators(df):
|
| 80 |
"""Calculates RSI, MACD, and Bollinger Bands."""
|
|
|
|
| 223 |
)
|
| 224 |
return fig
|
| 225 |
|
| 226 |
+
# --- Model Training and Prediction ---
|
| 227 |
+
model = RandomForestClassifier() # Moved here
|
| 228 |
+
|
| 229 |
+
def train_model(df):
|
| 230 |
+
"""Train the AI model."""
|
| 231 |
+
if df.empty:
|
| 232 |
+
return # Or raise an exception, or return a default model.
|
| 233 |
+
df["target"] = (df["close"].pct_change() > 0.05).astype(int) # Target: 1 if price increased by >5%
|
| 234 |
+
features = df[["close", "volume"]].dropna()
|
| 235 |
+
target = df["target"].dropna()
|
| 236 |
+
if not features.empty and not target.empty: #check data is available
|
| 237 |
+
model.fit(features, target)
|
| 238 |
+
else:
|
| 239 |
+
print("Not enough data for model training.")
|
| 240 |
+
|
| 241 |
+
def predict_growth(latest_data):
|
| 242 |
+
"""Predict explosive growth."""
|
| 243 |
+
if not hasattr(model, 'estimators_') or len(model.estimators_) == 0: # Check if model is trained
|
| 244 |
+
return [0] # Or return an error message, or a default value
|
| 245 |
+
|
| 246 |
+
try:
|
| 247 |
+
prediction = model.predict(latest_data.reshape(1, -1))
|
| 248 |
+
return prediction
|
| 249 |
+
except Exception as e:
|
| 250 |
+
print(f"Prediction error: {e}")
|
| 251 |
+
return [0]
|
| 252 |
+
|
| 253 |
# --- Main Prediction and Display Function ---
|
| 254 |
+
def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale, sentiment_keyword=""):
|
| 255 |
"""Main function to orchestrate data fetching, analysis, and prediction."""
|
| 256 |
df = pd.DataFrame()
|
| 257 |
error_message = ""
|
| 258 |
+
sentiment_score = 0 # Initialize sentiment score
|
| 259 |
# 1. Data Fetching
|
| 260 |
+
try:
|
| 261 |
+
if market_type == "Crypto":
|
| 262 |
+
df = fetch_crypto_data(symbol, interval=interval)
|
| 263 |
+
elif market_type == "Stock":
|
| 264 |
+
df = fetch_stock_data(symbol, interval=interval)
|
| 265 |
+
else:
|
| 266 |
+
error_message = "Invalid market type selected."
|
| 267 |
+
return None, None, None, None, None, "", error_message, 0 # Also return sentiment
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
+
if sentiment_keyword: # If a keyword for sentiment is entered:
|
| 270 |
+
sentiment_score = fetch_sentiment_data(sentiment_keyword)
|
| 271 |
+
except Exception as e:
|
| 272 |
+
error_message = f"Data Fetching Error: {e}"
|
| 273 |
+
return None, None, None, None, None, "", error_message, 0 #Return error + sentiment
|
| 274 |
|
| 275 |
+
if df.empty:
|
| 276 |
+
error_message = "No data fetched."
|
| 277 |
+
return None, None, None, None, None, "", error_message, 0 # Return empty + sentiment
|
| 278 |
# 2. Preprocessing & Technical Analysis
|
| 279 |
+
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
| 280 |
numeric_cols = ["open", "high", "low", "close", "volume"]
|
| 281 |
df[numeric_cols] = df[numeric_cols].astype(float)
|
| 282 |
df = calculate_technical_indicators(df)
|
|
|
|
| 297 |
|
| 298 |
if prophet_error:
|
| 299 |
error_message = f"Prophet Error: {prophet_error}"
|
| 300 |
+
return None, None, None, None, None, "", error_message, sentiment_score #Return prophet error
|
| 301 |
|
| 302 |
forecast_plot = create_forecast_plot(forecast_df)
|
| 303 |
|
|
|
|
| 321 |
# Prepare forecast data for the Dataframe output
|
| 322 |
forecast_df_display = forecast_df.loc[:, ["ds", "yhat", "yhat_lower", "yhat_upper"]].copy()
|
| 323 |
forecast_df_display.rename(columns={"ds": "Date", "yhat": "Forecast", "yhat_lower": "Lower Bound", "yhat_upper": "Upper Bound"}, inplace=True)
|
| 324 |
+
return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df_display, growth_label, error_message, sentiment_score
|
| 325 |
|
|
|
|
| 326 |
# --- Gradio Interface ---
|
| 327 |
with gr.Blocks(theme=gr.themes.Base()) as demo:
|
| 328 |
gr.Markdown("# Market Analysis and Prediction")
|
|
|
|
| 338 |
yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
|
| 339 |
seasonality_mode_dd = gr.Dropdown(label="Seasonality Mode", choices=["additive", "multiplicative"], value="additive")
|
| 340 |
changepoint_scale_slider = gr.Slider(label="Changepoint Prior Scale", minimum=0.01, maximum=1.0, step=0.01, value=0.05)
|
| 341 |
+
sentiment_keyword_txt = gr.Textbox(label="Sentiment Keyword (optional)") #Add Sentiment input
|
| 342 |
|
| 343 |
with gr.Column():
|
| 344 |
forecast_plot = gr.Plot(label="Price Forecast")
|
|
|
|
| 349 |
macd_plot = gr.Plot(label="MACD")
|
| 350 |
forecast_df = gr.Dataframe(label="Forecast Data", headers=["Date", "Forecast", "Lower Bound", "Upper Bound"])
|
| 351 |
growth_label_output = gr.Label(label="Explosive Growth Prediction") # Added for prediction.
|
| 352 |
+
sentiment_label_output = gr.Number(label="Sentiment Score") # Added for sentiment output
|
| 353 |
|
| 354 |
# Event Listener to update symbol dropdown based on market type
|
| 355 |
def update_symbol_choices(market_type):
|
|
|
|
| 373 |
yearly_box,
|
| 374 |
seasonality_mode_dd,
|
| 375 |
changepoint_scale_slider,
|
| 376 |
+
sentiment_keyword_txt, # Add sentiment keyword to the input
|
| 377 |
],
|
| 378 |
+
outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df, growth_label_output, gr.Label(label="Error Message"), sentiment_label_output] # Add sentiment score to the output
|
| 379 |
)
|
| 380 |
|
| 381 |
if __name__ == "__main__":
|