# app.py
# Hugging Face Spaces status when this copy was captured: Build error
| import gradio as gr | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import yfinance as yf | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from datetime import datetime, timedelta | |
| import plotly.graph_objects as go | |
| from torch import nn | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
# Custom CNN-LSTM Model
class StockPredictionModel(nn.Module):
    """Hybrid 1-D CNN + LSTM regressor over a window of per-step features.

    Two length-preserving convolutions (kernel 3, padding 1) run along the
    time axis, a stacked ``batch_first`` LSTM consumes the convolved
    sequence, and a linear head maps the hidden state of the final time step
    to ``output_dim`` values.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        output_dim: Size of the output vector per sample.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Convolutional front end: input_dim -> 32 -> 64 channels.
        self.conv1 = nn.Conv1d(input_dim, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        # Sequence model over the 64 convolved channels.
        self.lstm = nn.LSTM(64, hidden_dim, num_layers, batch_first=True)
        # Regression head applied to the last time step.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_dim) to (batch, output_dim)."""
        # Conv1d wants channels on dim 1: (batch, input_dim, seq_len).
        conv_in = x.transpose(1, 2)
        conv_out = self.conv2(self.conv1(conv_in).relu()).relu()
        # Back to (batch, seq_len, 64) for the batch_first LSTM.
        sequence, _ = self.lstm(conv_out.transpose(1, 2))
        return self.fc(sequence[:, -1, :])
class StockPredictor:
    """Fetch price history, score news sentiment, and roll a price forecast.

    Combines the ``ProsusAI/finbert`` sequence classifier (news sentiment)
    with :class:`StockPredictionModel` (price forecasting).

    NOTE(review): ``prediction_model`` is freshly constructed in ``__init__``
    and no trained weights are loaded in the code shown, so its outputs are
    effectively random and not on the price scale until weights are loaded.
    Input features are also unnormalized (Volume dwarfs the others).
    TODO confirm intended.
    """

    def __init__(self):
        # Initialize sentiment model
        self.sentiment_model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert")
        self.tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")
        # Initialize prediction model
        self.prediction_model = StockPredictionModel(
            input_dim=6,  # Close, Volume, RSI, MACD, Signal, sentiment (in that order)
            hidden_dim=64,
            num_layers=2,
            output_dim=1,
        )
        # Inference-only use; keeps behaviour correct if dropout etc. is added later.
        self.prediction_model.eval()

    def calculate_technical_indicators(self, df):
        """Add 'RSI', 'MACD' and 'Signal' columns to ``df`` in place and return it.

        RSI uses a 14-period simple rolling mean of gains and losses (not
        Wilder's smoothing). The first 13 rows are NaN; the value is 100 when
        there are no losses in the window and NaN when the price is flat.
        MACD is EMA(12) - EMA(26) of 'Close', and Signal is EMA(9) of MACD.
        """
        # Calculate RSI
        delta = df['Close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        rs = gain / loss
        df['RSI'] = 100 - (100 / (1 + rs))
        # Calculate MACD
        exp1 = df['Close'].ewm(span=12, adjust=False).mean()
        exp2 = df['Close'].ewm(span=26, adjust=False).mean()
        df['MACD'] = exp1 - exp2
        df['Signal'] = df['MACD'].ewm(span=9, adjust=False).mean()
        return df

    def get_stock_data(self, ticker, period='1y'):
        """Download price history for ``ticker`` via yfinance and add indicators.

        Returns:
            ``(df, None)`` on success, or ``(None, message)`` when no ticker is
            given, the download raises, or no rows come back.
        """
        symbol = (ticker or '').strip()
        if not symbol:
            return None, "Error fetching data: no ticker symbol provided"
        try:
            df = yf.Ticker(symbol).history(period=period)
            # Presumably yfinance returns an empty frame rather than raising
            # for unknown symbols; an empty frame would later crash Conv1d.
            if df is None or df.empty:
                return None, f"Error fetching data: no price history returned for '{symbol}'"
            df = self.calculate_technical_indicators(df)
        except Exception as e:  # UI boundary: report any fetch failure as a message
            return None, f"Error fetching data: {str(e)}"
        return df, None

    def analyze_sentiment(self, text):
        """Return FinBERT class probabilities for ``text`` as a list of floats.

        Positions follow the model's class indices; map them to labels with
        ``self.sentiment_model.config.id2label``. ``None`` is treated as "".
        """
        inputs = self.tokenizer(text or "", return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():  # inference only: no autograd graph needed
            outputs = self.sentiment_model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=1)
        return probabilities[0].tolist()

    def _positive_sentiment_index(self):
        """Return the class index the sentiment model uses for 'positive'."""
        config = getattr(self.sentiment_model, "config", None)
        id2label = getattr(config, "id2label", None) or {}
        for index, label in id2label.items():
            if str(label).lower() == "positive":
                return int(index)
        # NOTE(review): fallback assumes ProsusAI/finbert's published mapping
        # (0 = positive); verify if a different checkpoint is used.
        return 0

    def predict(self, ticker, news_text, prediction_days, window=30):
        """Forecast ``prediction_days`` future closes autoregressively.

        Args:
            ticker: Symbol passed to :meth:`get_stock_data`.
            news_text: Text scored by FinBERT. Its positive-class probability
                is used as a constant sentiment feature on every row.
            prediction_days: Number of steps to roll forward. It is coerced to
                int because UI sliders may deliver floats.
            window: Number of most recent complete rows fed to the model;
                must be positive. Defaults to 30.

        Returns:
            ``(predictions, None)`` with one float per step, or
            ``(None, message)`` on failure.
        """
        df, error = self.get_stock_data(ticker)
        if error:
            return None, error

        # RSI is NaN for the first rows of its rolling window; NaNs would
        # propagate through the network, so keep only complete rows.
        columns = ['Close', 'Volume', 'RSI', 'MACD', 'Signal']
        df = df.dropna(subset=columns)
        if df.empty:
            return None, "Not enough price history to compute technical indicators"

        sentiment_scores = self.analyze_sentiment(news_text)
        sentiment_value = sentiment_scores[self._positive_sentiment_index()]

        # Per-row feature layout: Close, Volume, RSI, MACD, Signal, sentiment.
        features = torch.tensor(df[columns].values, dtype=torch.float32)
        sentiment_column = torch.full((len(features), 1), sentiment_value)
        features = torch.cat([features, sentiment_column], dim=1)

        predictions = []
        with torch.no_grad():
            current_input = features[-window:].unsqueeze(0)
            for _ in range(int(prediction_days)):
                next_close = self.prediction_model(current_input).item()
                predictions.append(next_close)
                # Slide the window: copy the last row and replace only Close
                # with the prediction. Volume, indicators and sentiment are
                # carried forward unchanged.
                new_row = current_input[:, -1:, :].clone()
                new_row[0, 0, 0] = next_close
                current_input = torch.cat([current_input[:, 1:, :], new_row], dim=1)
        return predictions, None
def create_prediction_plot(historical_data, predictions, ticker):
    """Chart historical closing prices next to the forecast values.

    Args:
        historical_data: Frame with a datetime-like index and a 'Close' column.
        predictions: Sequence of forecast values, one per future step.
        ticker: Symbol shown in the chart title.

    Returns:
        A plotly ``go.Figure`` with a solid blue 'Historical' trace and a
        dashed red 'Prediction' trace.
    """
    anchor = historical_data.index[-1]
    # One calendar day per forecast step after the last observed timestamp
    # (weekends and market holidays are not skipped).
    horizon = [anchor + timedelta(days=step) for step in range(1, len(predictions) + 1)]

    history_trace = go.Scatter(
        x=historical_data.index,
        y=historical_data['Close'],
        name='Historical',
        line=dict(color='blue'),
    )
    forecast_trace = go.Scatter(
        x=horizon,
        y=predictions,
        name='Prediction',
        line=dict(color='red', dash='dash'),
    )

    figure = go.Figure()
    figure.add_trace(history_trace)
    figure.add_trace(forecast_trace)
    figure.update_layout(
        title=f'{ticker} Stock Price Prediction',
        xaxis_title='Date',
        yaxis_title='Price',
        hovermode='x',
    )
    return figure
# Shared predictor, created on first request. Reusing it avoids reloading
# FinBERT on every click and keeps the forecaster's weights fixed, so repeated
# identical requests give consistent results within a process.
_PREDICTOR = None


def _get_predictor():
    """Return the process-wide StockPredictor, constructing it on first use."""
    global _PREDICTOR
    if _PREDICTOR is None:
        _PREDICTOR = StockPredictor()
    return _PREDICTOR


def predict_stock(ticker, news_text, prediction_days):
    """Gradio callback: forecast prices for ``ticker`` and build a chart.

    Args:
        ticker: Stock symbol entered by the user.
        news_text: Free text used for sentiment scoring.
        prediction_days: Number of days to forecast.

    Returns:
        ``(status_message, figure)``. On failure the figure is ``None`` and the
        message starts with ``"Error: "``.
    """
    predictor = _get_predictor()
    # Get predictions
    predictions, error = predictor.predict(ticker, news_text, prediction_days)
    if error:
        return f"Error: {error}", None
    # Get historical data for plotting (NOTE: predict() fetched this too).
    historical_data, error = predictor.get_stock_data(ticker)
    if error:
        return f"Error: {error}", None
    # Create plot
    plot = create_prediction_plot(historical_data, predictions, ticker)
    return "Prediction successful!", plot
# Create Gradio interface
def _build_interface():
    """Assemble the Gradio UI that wires three inputs to ``predict_stock``."""
    ticker_box = gr.Textbox(label="Stock Ticker (e.g., AAPL)")
    news_box = gr.Textbox(label="Recent News or Analysis", lines=3)
    days_slider = gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days", value=7)

    status_box = gr.Textbox(label="Status")
    chart = gr.Plot(label="Prediction Plot")

    return gr.Interface(
        fn=predict_stock,
        inputs=[ticker_box, news_box, days_slider],
        outputs=[status_box, chart],
        title="Stock Price Prediction with Sentiment Analysis",
        description="Enter a stock ticker, recent news, and prediction period to get stock price forecasts.",
        theme="default",
    )


iface = _build_interface()

if __name__ == "__main__":
    iface.launch()