# AI-Trading / app.py
# Author: shaheerawan3
# Commit: 07f3be1 "Create app.py" (verified)
# app.py
import gradio as gr
import torch
import pandas as pd
import numpy as np
import yfinance as yf
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datetime import datetime, timedelta
import plotly.graph_objects as go
from torch import nn
import warnings
# Suppress all Python warnings process-wide (action 'ignore' for every category).
# NOTE(review): this also hides potentially useful diagnostics from the libraries
# imported above; consider narrowing to specific categories/modules.
warnings.filterwarnings('ignore')
# Custom CNN-LSTM Model
class StockPredictionModel(nn.Module):
    """1-D CNN feature extractor followed by a stacked LSTM and a linear head.

    ``forward`` expects a batch-first tensor of shape
    ``(batch, seq_len, input_dim)`` and returns ``(batch, output_dim)``,
    computed from the LSTM output at the final time step.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Two same-padded convolutions (kernel 3, padding 1) keep seq_len unchanged.
        self.conv1 = nn.Conv1d(input_dim, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        # The LSTM consumes the 64 convolution channels at each time step.
        self.lstm = nn.LSTM(64, hidden_dim, num_layers, batch_first=True)
        # Linear head mapping the last hidden state to the prediction.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Conv1d wants (batch, channels, length): swap the time and feature axes.
        channels_first = x.transpose(1, 2)
        channels_first = torch.relu(self.conv2(torch.relu(self.conv1(channels_first))))
        # Back to (batch, length, channels) for the batch-first LSTM.
        sequence = channels_first.transpose(1, 2)
        hidden_states, _ = self.lstm(sequence)
        return self.fc(hidden_states[:, -1])
class StockPredictor:
    """Forecast closing prices from price history, technical indicators and news sentiment.

    News sentiment is scored with the pretrained FinBERT classifier
    ("ProsusAI/finbert"); prices are forecast by rolling a
    ``StockPredictionModel`` forward one step at a time.

    NOTE(review): ``prediction_model`` is freshly constructed in ``__init__``
    and nothing in this class loads trained weights into it, so its forecasts
    are not meaningful until a trained ``state_dict`` is loaded.
    """

    # Market-data feature columns, in model-input order. The sentiment score is
    # appended after them as the sixth and last feature (see ``_forecast``).
    FEATURE_COLUMNS = ['Close', 'Volume', 'RSI', 'MACD', 'Signal']

    def __init__(self):
        # FinBERT sentiment classifier plus its matching tokenizer.
        self.sentiment_model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert")
        self.tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")
        # CNN-LSTM price regressor; input_dim must equal the number of feature
        # columns built in ``_forecast`` (5 market columns + 1 sentiment column).
        self.prediction_model = StockPredictionModel(
            input_dim=6,  # Close, Volume, RSI, MACD, Signal, sentiment
            hidden_dim=64,
            num_layers=2,
            output_dim=1,
        )
        # The regressor is only used for inference here.
        self.prediction_model.eval()

    def calculate_technical_indicators(self, df):
        """Add ``RSI``, ``MACD`` and ``Signal`` columns to *df* in place and return it.

        * RSI: 14-row simple rolling means of the gains and losses of ``Close``.
        * MACD: 12-span EMA minus 26-span EMA of ``Close``.
        * Signal: 9-span EMA of MACD.

        The first 13 RSI values are NaN (rolling warm-up), and RSI is also NaN
        wherever both the average gain and the average loss are zero.
        """
        # RSI
        delta = df['Close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        rs = gain / loss
        df['RSI'] = 100 - (100 / (1 + rs))
        # MACD and its signal line
        exp1 = df['Close'].ewm(span=12, adjust=False).mean()
        exp2 = df['Close'].ewm(span=26, adjust=False).mean()
        df['MACD'] = exp1 - exp2
        df['Signal'] = df['MACD'].ewm(span=9, adjust=False).mean()
        return df

    def get_stock_data(self, ticker, period='1y'):
        """Download price history for *ticker* via yfinance and add indicators.

        Returns:
            ``(df, None)`` on success, or ``(None, message)`` if the download
            raises or yields no rows.
        """
        try:
            df = yf.Ticker(ticker).history(period=period)
            # An unknown/delisted symbol may come back as an empty frame rather
            # than an exception; report it here instead of crashing later.
            if df.empty:
                return None, f"Error fetching data: no price history found for '{ticker}'"
            df = self.calculate_technical_indicators(df)
        except Exception as e:
            return None, f"Error fetching data: {str(e)}"
        return df, None

    def analyze_sentiment(self, text):
        """Return FinBERT class probabilities for *text* as a list ordered by label id.

        Map indices to label names with ``self.sentiment_model.config.id2label``.
        """
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            outputs = self.sentiment_model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=1)
        return probabilities[0].tolist()

    @staticmethod
    def _label_score(probabilities, id2label, label='positive'):
        """Return the entry of *probabilities* whose id maps to *label* (case-insensitive).

        Raises:
            KeyError: if no id in *id2label* maps to *label*.
        """
        wanted = label.lower()
        for index, name in id2label.items():
            if str(name).lower() == wanted:
                return probabilities[int(index)]
        raise KeyError(f"label {label!r} not found in {id2label!r}")

    def _forecast(self, df, sentiment_value, prediction_days, window=30):
        """Roll ``prediction_model`` forward *prediction_days* steps over the last *window* rows.

        Returns:
            ``(predictions, None)`` with one float per step, or
            ``(None, message)`` when no row has a complete feature set.
        """
        # Rolling indicators are NaN during warm-up; a NaN anywhere in the
        # window would propagate through the network into every prediction.
        usable = df[self.FEATURE_COLUMNS].dropna()
        if usable.empty:
            return None, "Not enough price history to compute technical indicators"
        steps = int(prediction_days)  # Gradio sliders may deliver floats.

        features = torch.tensor(usable.values, dtype=torch.float32)
        sentiment_column = torch.full((len(features), 1), float(sentiment_value))
        features = torch.cat([features, sentiment_column], dim=1)

        predictions = []
        current_input = features[-window:].unsqueeze(0)  # (1, <=window, 6)
        with torch.no_grad():
            for _ in range(steps):
                next_close = self.prediction_model(current_input).item()
                predictions.append(next_close)
                # Carry the last row's volume, indicators and sentiment forward,
                # replacing only the close, then slide the window by one step.
                new_row = current_input[:, -1:, :].clone()
                new_row[0, 0, 0] = next_close
                current_input = torch.cat([current_input[:, 1:, :], new_row], dim=1)
        return predictions, None

    def predict(self, ticker, news_text, prediction_days):
        """Forecast *prediction_days* closing prices for *ticker* given *news_text*.

        Returns:
            ``(predictions, None)`` with a list of floats, or ``(None, message)``
            on a data error.
        """
        df, error = self.get_stock_data(ticker)
        if error:
            return None, error
        probabilities = self.analyze_sentiment(news_text)
        # Look the positive class up by name: in the ProsusAI/finbert config
        # id 1 is "negative", so a hard-coded index is wrong.
        sentiment_value = self._label_score(
            probabilities, self.sentiment_model.config.id2label, 'positive'
        )
        return self._forecast(df, sentiment_value, prediction_days)
def create_prediction_plot(historical_data, predictions, ticker):
    """Plot historical closing prices followed by the forecast as a dashed line.

    Args:
        historical_data: DataFrame with a ``'Close'`` column; its index is used
            as the x axis (presumably a DatetimeIndex from yfinance — confirm).
        predictions: Sequence of forecast prices, one per future trading step.
        ticker: Symbol shown in the chart title.

    Returns:
        A ``plotly.graph_objects.Figure`` with "Historical" and "Prediction" traces.
    """
    future_dates = []
    if len(historical_data.index):
        last_date = historical_data.index[-1]
        # Each prediction is one step of the trading-day series, so lay the
        # forecast out on business days rather than calendar days (weekends
        # are skipped; exchange holidays are not accounted for).
        future_dates = [
            last_date + pd.offsets.BDay(step) for step in range(1, len(predictions) + 1)
        ]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=historical_data.index,
        y=historical_data['Close'],
        name='Historical',
        line=dict(color='blue')
    ))
    fig.add_trace(go.Scatter(
        x=future_dates,
        y=predictions,
        name='Prediction',
        line=dict(color='red', dash='dash')
    ))
    fig.update_layout(
        title=f'{ticker} Stock Price Prediction',
        xaxis_title='Date',
        yaxis_title='Price',
        hovermode='x'
    )
    return fig
# Lazily-built StockPredictor shared across requests (see _get_predictor).
_cached_predictor = None


def _get_predictor():
    """Return a process-wide StockPredictor, constructing it on first use.

    Construction loads FinBERT from ``from_pretrained``, which is expensive, so
    the instance is reused across Gradio calls instead of rebuilt per request.
    """
    global _cached_predictor
    if _cached_predictor is None:
        _cached_predictor = StockPredictor()
    return _cached_predictor


def predict_stock(ticker, news_text, prediction_days):
    """Gradio callback: forecast *ticker* and return ``(status_message, figure)``.

    On failure the status message starts with ``"Error: "`` and the figure is ``None``.
    """
    # Normalise user-typed symbols such as " aapl " and reject blank input early.
    ticker = (ticker or "").strip().upper()
    if not ticker:
        return "Error: please enter a stock ticker", None
    predictor = _get_predictor()
    predictions, error = predictor.predict(ticker, news_text or "", prediction_days)
    if error:
        return f"Error: {error}", None
    # Historical data for the chart (predict() does not return the frame it used).
    historical_data, error = predictor.get_stock_data(ticker)
    if error:
        return f"Error: {error}", None
    plot = create_prediction_plot(historical_data, predictions, ticker)
    return "Prediction successful!", plot
# Gradio UI: ticker, news text and horizon in; status text and plot out.
_input_components = [
    gr.Textbox(label="Stock Ticker (e.g., AAPL)"),
    gr.Textbox(lines=3, label="Recent News or Analysis"),
    gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Prediction Days"),
]
_output_components = [
    gr.Textbox(label="Status"),
    gr.Plot(label="Prediction Plot"),
]
iface = gr.Interface(
    fn=predict_stock,
    inputs=_input_components,
    outputs=_output_components,
    title="Stock Price Prediction with Sentiment Analysis",
    description="Enter a stock ticker, recent news, and prediction period to get stock price forecasts.",
    theme="default",
)

if __name__ == "__main__":
    iface.launch()