Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import yfinance as yf | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from datetime import datetime, timedelta | |
| import plotly.graph_objects as go | |
| from torch import nn | |
| import warnings | |
| from bs4 import BeautifulSoup | |
| import requests | |
| from scipy.signal import savgol_filter | |
| import threading | |
| import time | |
# Silence every warning process-wide (e.g. library deprecation notices).
# NOTE(review): this also hides warnings that may point at real problems;
# consider narrowing it to specific categories or modules.
warnings.filterwarnings('ignore')


class EnhancedStockPredictionModel(nn.Module):
    """CNN -> self-attention -> bidirectional LSTM -> MLP sequence regressor.

    Input is a batch of feature windows shaped ``(batch, seq_len, input_dim)``
    (the layout implied by the ``permute(0, 2, 1)`` that feeds ``Conv1d``);
    the output is ``(batch, output_dim)``, one value per window.

    NOTE(review): nothing in this file trains or loads weights for this
    model, so its outputs come from random initialisation.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        """Build the layers.

        Args:
            input_dim: number of features per time step (Conv1d input channels).
            hidden_dim: LSTM hidden size per direction.
            num_layers: number of stacked LSTM layers.
            output_dim: size of the regression output.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # Two length-preserving convolutions (kernel 3, padding 1) with batch norm.
        self.conv1 = nn.Conv1d(input_dim, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(32)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(64)
        # Self-attention over time. batch_first is left at its default (False),
        # so it expects (seq, batch, embed) input.
        self.attention = nn.MultiheadAttention(64, 4)
        self.lstm = nn.LSTM(64, hidden_dim, num_layers, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.2)
        # Forward and backward LSTM summaries are concatenated -> 2 * hidden_dim.
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``(batch, seq_len, input_dim)`` windows to ``(batch, output_dim)``."""
        x = x.permute(0, 2, 1)                    # (batch, input_dim, seq)
        x = self.bn1(torch.relu(self.conv1(x)))
        x = self.bn2(torch.relu(self.conv2(x)))   # (batch, 64, seq)
        x = x.permute(2, 0, 1)                    # (seq, batch, 64) for attention
        x, _ = self.attention(x, x, x)
        x = x.permute(1, 0, 2)                    # (batch, seq, 64) for the LSTM
        lstm_out, _ = self.lstm(x)                # (batch, seq, 2 * hidden_dim)
        # The forward direction has consumed the whole window at the LAST
        # step, but the backward direction has consumed it at the FIRST step.
        # Taking lstm_out[:, -1] for both halves would give the backward
        # direction a summary of a single time step.
        forward_final = lstm_out[:, -1, :self.hidden_dim]
        backward_final = lstm_out[:, 0, self.hidden_dim:]
        last_output = torch.cat([forward_final, backward_final], dim=1)
        x = self.dropout(torch.relu(self.fc1(last_output)))
        return self.fc2(x)
class EnhancedStockPredictor:
    """Fetch price history, derive indicators, score sentiment and roll a forecast.

    Sentiment comes from the ``ProsusAI/finbert`` sequence classifier; the
    forecast comes from an ``EnhancedStockPredictionModel`` fed 8 features per
    day. Fetched price data is cached in memory per ``(ticker, period)``.

    NOTE(review): the prediction model is freshly constructed here and no
    weights are loaded in this file, so its forecasts are not trained output.
    """

    # Per-day model features, in column order; the sentiment score is
    # appended as the 8th column in ``predict``.
    FEATURE_COLUMNS = ['Close', 'Volume', 'RSI', 'MACD', 'Signal', 'Bollinger', 'Volume_MA']
    # Number of trailing days fed to the model at each forecast step.
    WINDOW = 30
    # How long fetched price data is reused before refetching.
    CACHE_TTL_SECONDS = 300
    # Score used when there is no text to analyse or scraping fails.
    NEUTRAL_SENTIMENT = 0.5

    def __init__(self):
        self.sentiment_model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert")
        self.tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")
        self.prediction_model = EnhancedStockPredictionModel(
            input_dim=8,  # Close, Volume, RSI, MACD, Signal, Bollinger, Volume_MA, sentiment
            hidden_dim=128,
            num_layers=3,
            output_dim=1,
        )
        # Inference only: disable dropout and stop BatchNorm from updating
        # its running statistics on every request.
        self.prediction_model.eval()
        # (ticker, period) -> (indicator DataFrame, fetch timestamp)
        self.cache = {}
        self.cache_lock = threading.Lock()

    def get_news_sentiment(self, ticker):
        """Return the positive-sentiment probability of up to 5 Yahoo headlines.

        Best-effort: any failure (network error, HTTP error, markup change,
        model error) yields the neutral score 0.5, as does finding no headlines.
        """
        url = f"https://finance.yahoo.com/quote/{ticker}/news"
        headers = {'User-Agent': 'Mozilla/5.0'}
        try:
            # Without a timeout a stalled connection would hang the request.
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')
            # NOTE(review): this CSS class is tied to Yahoo's page markup and
            # may silently stop matching; verify against the live page.
            news_items = soup.find_all('h3', class_='Mb(5px)')
            news_text = ' '.join([item.text for item in news_items[:5]])
            return self.analyze_sentiment(news_text)
        except Exception:
            # Deliberate best-effort boundary: sentiment is an optional signal.
            return self.NEUTRAL_SENTIMENT

    def calculate_technical_indicators(self, df):
        """Add indicator columns to ``df`` in place and return it.

        Requires 'Close' and 'Volume' columns. Adds RSI (14-day simple
        averages), MACD (12/26 EMA difference), Signal (9-day EMA of MACD
        smoothed by a Savitzky-Golay filter), MA20, Bollinger_Upper/Lower
        (MA20 +/- 2 std), Bollinger (distance from MA20 in units of 2 std)
        and Volume_MA (20-day). Leading rows are NaN until each rolling
        window fills.
        """
        delta = df['Close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        rs = gain / loss
        df['RSI'] = 100 - (100 / (1 + rs))

        exp1 = df['Close'].ewm(span=12, adjust=False).mean()
        exp2 = df['Close'].ewm(span=26, adjust=False).mean()
        df['MACD'] = exp1 - exp2
        # savgol_filter needs at least 5 samples (window_length=5).
        # NOTE(review): the filter window is centred, so each smoothed value
        # also uses the following days — look-ahead if used for training.
        df['Signal'] = savgol_filter(df['MACD'].ewm(span=9, adjust=False).mean(), 5, 3)

        df['MA20'] = df['Close'].rolling(window=20).mean()
        std = df['Close'].rolling(window=20).std()
        df['Bollinger_Upper'] = df['MA20'] + (std * 2)
        df['Bollinger_Lower'] = df['MA20'] - (std * 2)
        df['Bollinger'] = (df['Close'] - df['MA20']) / (std * 2)

        df['Volume_MA'] = df['Volume'].rolling(window=20).mean()
        return df

    def get_stock_data(self, ticker, period='1y'):
        """Return ``(df, error)`` for ``ticker`` with indicator columns added.

        Results are cached per ``(ticker, period)`` for CACHE_TTL_SECONDS.
        On failure ``df`` is None and ``error`` is a message string.
        """
        cache_key = (ticker, period)
        current_time = time.time()
        with self.cache_lock:
            cached = self.cache.get(cache_key)
        if cached is not None and current_time - cached[1] < self.CACHE_TTL_SECONDS:
            return cached[0], None
        try:
            df = yf.Ticker(ticker).history(period=period)
            if df is None or df.empty:
                # An unknown symbol comes back empty rather than raising.
                return None, f"No price data found for ticker '{ticker}'"
            df = self.calculate_technical_indicators(df)
        except Exception as e:
            return None, f"Error fetching data: {str(e)}"
        with self.cache_lock:
            self.cache[cache_key] = (df, current_time)
        return df, None

    def analyze_sentiment(self, text):
        """Return FinBERT's probability (0..1) that ``text`` is positive.

        Empty or whitespace-only text returns the neutral score 0.5 without
        running the model.
        """
        if not text or not text.strip():
            return self.NEUTRAL_SENTIMENT
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():  # inference only; don't build an autograd graph
            logits = self.sentiment_model(**inputs).logits
        probabilities = torch.softmax(logits, dim=1)[0]
        return probabilities[self._positive_label_index()].item()

    def _positive_label_index(self):
        """Return the logit index whose config label is 'positive'."""
        label2id = getattr(self.sentiment_model.config, 'label2id', None) or {}
        for label, index in label2id.items():
            if str(label).lower() == 'positive':
                return int(index)
        # NOTE(review): fallback assumes the ProsusAI/finbert label order
        # (positive, negative, neutral); confirm if the model is swapped.
        return 0

    def predict(self, ticker, news_text, prediction_days):
        """Roll the model forward ``prediction_days`` steps.

        Sentiment is the mean of the scraped-headline score and the score of
        ``news_text``, broadcast as a constant feature column. Each step's
        prediction replaces 'Close' in a copy of the latest row, which is
        appended to the sliding window; other features are carried forward.

        Returns:
            ``(predictions, confidence_intervals, None)`` on success, where
            each interval is ``[prediction - s, prediction + s]`` with ``s``
            the std of the closes in the current window;
            ``(None, None, message)`` on failure.
        """
        prediction_days = int(prediction_days)  # UI sliders may deliver floats
        if prediction_days < 1:
            return None, None, "Prediction days must be at least 1"

        df, error = self.get_stock_data(ticker)
        if error:
            return None, None, error

        scraped_sentiment = self.get_news_sentiment(ticker)
        manual_sentiment = self.analyze_sentiment(news_text)
        sentiment_value = (scraped_sentiment + manual_sentiment) / 2

        # Rows before the rolling windows fill are NaN and would poison the input.
        feature_frame = df[self.FEATURE_COLUMNS].dropna()
        if feature_frame.empty:
            return None, None, "Not enough price history to compute indicators"
        features = torch.tensor(feature_frame.values, dtype=torch.float32)
        sentiment_column = torch.full((len(features), 1), sentiment_value)
        features = torch.cat([features, sentiment_column], dim=1)

        self.prediction_model.eval()
        predictions = []
        confidence_intervals = []
        with torch.no_grad():
            current_input = features[-self.WINDOW:].unsqueeze(0)  # (1, window, 8)
            for _ in range(prediction_days):
                base_prediction = self.prediction_model(current_input).item()
                window_closes = current_input[0, :, 0]
                # torch.std of a single sample is NaN; use a zero-width band instead.
                std_dev = torch.std(window_closes).item() if window_closes.numel() > 1 else 0.0
                confidence_intervals.append([base_prediction - std_dev, base_prediction + std_dev])
                predictions.append(base_prediction)

                new_row = current_input[0, -1].clone()
                new_row[0] = base_prediction  # only Close changes
                current_input = torch.cat([current_input[:, 1:, :], new_row.view(1, 1, -1)], dim=1)
        return predictions, confidence_intervals, None
def create_enhanced_prediction_plot(historical_data, predictions, confidence_intervals, ticker):
    """Build a Plotly figure of history, forecast, interval band and overlays.

    Args:
        historical_data: price DataFrame indexed by date with 'Close', 'MA20',
            'Bollinger_Upper' and 'Bollinger_Lower' columns.
        predictions: forecast values, one per future trading day.
        confidence_intervals: ``[low, high]`` pairs aligned with ``predictions``.
        ticker: symbol used in the chart title.

    Returns:
        A ``plotly.graph_objects.Figure`` using the 'plotly_dark' template.
    """
    last_date = historical_data.index[-1]
    # Each forecast step is one trading day, so place the points on business
    # days instead of calendar days (no forecast points on weekends).
    future_dates = [last_date + pd.offsets.BDay(i + 1) for i in range(len(predictions))]
    lower = [ci[0] for ci in confidence_intervals]
    upper = [ci[1] for ci in confidence_intervals]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=historical_data.index,
        y=historical_data['Close'],
        name='Historical',
        line=dict(color='blue'),
    ))
    fig.add_trace(go.Scatter(
        x=future_dates,
        y=predictions,
        name='Prediction',
        line=dict(color='red', dash='dash'),
    ))
    # Closed polygon: out along the lower bound, back along the upper bound.
    fig.add_trace(go.Scatter(
        x=future_dates + future_dates[::-1],
        y=lower + upper[::-1],
        fill='toself',
        fillcolor='rgba(255,0,0,0.1)',
        line=dict(color='rgba(255,0,0,0)'),
        name='Confidence Interval',
    ))
    # Dotted indicator overlays on the historical range.
    overlays = (
        ('MA20', '20-day MA', 'green'),
        ('Bollinger_Upper', 'Bollinger Upper', 'gray'),
        ('Bollinger_Lower', 'Bollinger Lower', 'gray'),
    )
    for column, label, color in overlays:
        fig.add_trace(go.Scatter(
            x=historical_data.index,
            y=historical_data[column],
            name=label,
            line=dict(color=color, dash='dot'),
        ))

    fig.update_layout(
        title=f'{ticker} Stock Price Prediction with Technical Indicators',
        xaxis_title='Date',
        yaxis_title='Price',
        hovermode='x',
        showlegend=True,
        template='plotly_dark',
    )
    return fig
# One shared predictor for all requests. Building an EnhancedStockPredictor
# reloads FinBERT and constructs a new, randomly initialised forecasting
# model with an empty data cache, so doing it per request was both slow and
# made forecasts change from call to call.
_shared_predictor = None
_shared_predictor_lock = threading.Lock()


def _get_shared_predictor():
    """Return the process-wide EnhancedStockPredictor, creating it on first use."""
    global _shared_predictor
    with _shared_predictor_lock:
        if _shared_predictor is None:
            _shared_predictor = EnhancedStockPredictor()
        return _shared_predictor


def predict_stock(ticker, news_text, prediction_days):
    """Gradio handler: return ``(analysis_text, figure)`` for one request.

    Args:
        ticker: stock symbol as typed by the user (trimmed and upper-cased).
        news_text: optional free text scored for sentiment.
        prediction_days: forecast horizon in trading days.

    Returns:
        ``(analysis, plot)`` on success, or ``("Error: <message>", None)``.
    """
    ticker = (ticker or '').strip().upper()
    if not ticker:
        return "Error: please enter a stock ticker", None

    predictor = _get_shared_predictor()
    # "*payload, error" accepts both a (None, error) failure tuple and a
    # (predictions, intervals, error) result.
    *payload, error = predictor.predict(ticker, news_text, prediction_days)
    if error:
        return f"Error: {error}", None
    predictions, confidence_intervals = payload
    if not predictions:
        return "Error: no predictions were produced", None

    historical_data, error = predictor.get_stock_data(ticker)
    if error:
        return f"Error: {error}", None

    plot = create_enhanced_prediction_plot(historical_data, predictions, confidence_intervals, ticker)

    current_price = historical_data['Close'].iloc[-1]
    predicted_price = predictions[0]
    percent_change = ((predicted_price - current_price) / current_price) * 100
    rsi = historical_data['RSI'].iloc[-1]
    macd = historical_data['MACD'].iloc[-1]
    rsi_state = 'Overbought' if rsi > 70 else 'Oversold' if rsi < 30 else 'Neutral'
    macd_state = 'Bullish' if macd > 0 else 'Bearish'
    low, high = confidence_intervals[0]

    analysis = "\n".join([
        f"Current Price: ${current_price:.2f}",
        f"Next Day Prediction: ${predicted_price:.2f} ({percent_change:+.2f}%)",
        f"RSI: {rsi:.2f} ({rsi_state})",
        f"MACD: {macd:.2f} ({macd_state})",
        f"Confidence Interval: ${low:.2f} to ${high:.2f}",
    ])
    return analysis, plot
# Gradio front end: a ticker, optional free-text news and a forecast horizon
# go in; a text summary and a Plotly chart come out.
_ui_inputs = [
    gr.Textbox(label="Stock Ticker (e.g., AAPL)"),
    gr.Textbox(label="Recent News or Analysis (Optional)", lines=3),
    gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days", value=7),
]
_ui_outputs = [
    gr.Textbox(label="Analysis"),
    gr.Plot(label="Advanced Prediction Plot"),
]

iface = gr.Interface(
    fn=predict_stock,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="🚀 Advanced Stock Price Prediction Platform",
    description="Enter a stock ticker, recent news (optional), and prediction period to get comprehensive stock analysis and forecasts.",
    theme="default",
)

# Launch the web server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()