# Page-header residue from the hosting site, kept as comments so the module parses:
# dure-waseem's picture
# initial code
# 6e65197
import gradio as gr
import sys
import os
import traceback
from crew import PredictingStock
import json
import re
import tempfile
import subprocess
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from typing import Dict, Any
from datetime import datetime, timedelta
import yfinance as yf
class StockDataProcessor:
    """Fetch recent daily price history for one ticker and derive chart/statistics views.

    A single instance is reused across requests; each ``get_stock_data`` call
    replaces whatever ticker was loaded before (a failed call clears it).

    Attributes:
        data: Rows returned by yfinance as ``{timestamp: {column: value}}``
            (``DataFrame.to_dict(orient="index")``), keeping yfinance's column
            names ("Open", "Close", "Stock Splits", ...). Used for the AI JSON export.
        df: The same rows as a date-sorted DataFrame with snake_case column
            names plus the indicator columns added by
            ``_calculate_technical_indicators``. ``None`` until data is loaded.
    """

    # Field order of each row in the JSON file handed to the AI crew.
    _JSON_FIELDS = ("Open", "High", "Low", "Close", "Volume", "Dividends", "Stock Splits")

    def __init__(self):
        self.data = None
        self.df = None

    def _reset(self):
        """Forget any previously loaded ticker so a failed fetch never leaves stale data behind."""
        self.data = None
        self.df = None

    @staticmethod
    def _build_frame(hist):
        """Return a date-sorted copy of a yfinance history frame with snake_case column names.

        Columns are renamed by name rather than by position ("Stock Splits" ->
        "stock_splits"), so extra columns that yfinance returns for some
        tickers (e.g. "Capital Gains" for funds) no longer break the rename.
        """
        frame = hist.copy()
        frame.index = pd.to_datetime(frame.index)
        frame.columns = [str(col).strip().lower().replace(" ", "_") for col in frame.columns]
        return frame.sort_index()

    def get_stock_data(self, ticker: str, days: int = 180) -> bool:
        """Download roughly the last *days* calendar days of history for *ticker* (no file storage).

        On success populates ``self.data`` and ``self.df`` (with indicators).

        Args:
            ticker: Symbol passed straight to ``yf.Ticker``.
            days: Look-back window in calendar days (default 180, about 6 months).

        Returns:
            True if data was loaded. False if yfinance returned no rows or any
            error occurred; in that case the error is written to stderr and
            previously loaded data is cleared.
        """
        try:
            end_date = datetime.now()
            start_date = end_date - timedelta(days=days)
            hist = yf.Ticker(ticker).history(
                start=start_date.strftime("%Y-%m-%d"),
                # yfinance treats `end` as exclusive: ask for tomorrow so today's bar is included.
                end=(end_date + timedelta(days=1)).strftime("%Y-%m-%d"),
            )
            if hist.empty:
                self._reset()
                return False
            self.data = hist.to_dict(orient="index")
            self.df = self._build_frame(hist)
            self._calculate_technical_indicators()
            return True
        except Exception as exc:  # network / parsing failures: report, then signal failure
            print(f"[StockDataProcessor] could not load {ticker!r}: {exc}", file=sys.stderr)
            self._reset()
            return False

    def create_temp_json_file(self, ticker: str) -> str:
        """Write the loaded rows to a new temporary ``.json`` file for the AI crew.

        The file maps ``"YYYY-MM-DD"`` to ``[Open, High, Low, Close, Volume,
        Dividends, Stock Splits]``. It is created with ``delete=False``, so the
        caller owns it and must remove it.

        Args:
            ticker: Unused; kept for interface compatibility.

        Returns:
            The path of the written file, or None if no data is loaded, a row
            lacks one of the expected fields, or writing fails. No file is left
            behind on failure.
        """
        if self.data is None:
            return None
        try:
            payload = {
                (key.strftime("%Y-%m-%d") if hasattr(key, "strftime") else str(key)): [
                    row[field] for field in self._JSON_FIELDS
                ]
                for key, row in self.data.items()
            }
        except Exception:
            return None  # malformed rows: nothing has been written yet, so nothing to clean up

        path = None
        try:
            with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
                path = handle.name
                json.dump(payload, handle)
        except Exception:
            if path is not None:
                try:
                    os.unlink(path)  # don't leak a half-written file
                except OSError:
                    pass
            return None
        return path

    def _calculate_technical_indicators(self):
        """Add return, moving-average, RSI, Bollinger-band and volume-average columns to ``self.df``."""
        close = self.df["close"]
        self.df["daily_return"] = close.pct_change()
        self.df["ma_20"] = close.rolling(window=20).mean()
        self.df["ma_50"] = close.rolling(window=50).mean()
        self.df["rsi"] = self._calculate_rsi(close)
        # Bollinger Bands: 20-day mean +/- 2 rolling standard deviations.
        self.df["bb_middle"] = self.df["ma_20"]
        bb_std = close.rolling(window=20).std()
        self.df["bb_upper"] = self.df["bb_middle"] + (bb_std * 2)
        self.df["bb_lower"] = self.df["bb_middle"] - (bb_std * 2)
        self.df["volume_ma"] = self.df["volume"].rolling(window=20).mean()

    def _calculate_rsi(self, prices, window=14):
        """Return the RSI series of *prices*.

        Uses simple rolling means of gains and losses over *window* periods
        (not Wilder's exponential smoothing). A window with no losses yields 100.
        """
        delta = prices.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=window).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=window).mean()
        rs = gain / loss
        return 100 - (100 / (1 + rs))

    def create_interactive_chart(self, ticker: str):
        """Build a Plotly figure of close price with 20/50-day MAs and Bollinger Bands.

        Returns None when no data is loaded.
        """
        if self.df is None or self.df.empty:
            return None
        fig = make_subplots(
            rows=1, cols=1,
            shared_xaxes=True,
            subplot_titles=(f'{ticker} - Stock Price with Technical Analysis',)
        )
        dates = self.df.index
        fig.add_trace(
            go.Scatter(x=dates, y=self.df['close'],
                       name='Close Price', line=dict(color='#2E86C1', width=2)),
            row=1, col=1
        )
        fig.add_trace(
            go.Scatter(x=dates, y=self.df['ma_20'],
                       name='20-day MA', line=dict(color='orange', width=1)),
            row=1, col=1
        )
        fig.add_trace(
            go.Scatter(x=dates, y=self.df['ma_50'],
                       name='50-day MA', line=dict(color='red', width=1)),
            row=1, col=1
        )
        # Upper band first so the lower band's fill='tonexty' shades the area between them.
        fig.add_trace(
            go.Scatter(x=dates, y=self.df['bb_upper'],
                       line=dict(color='gray', width=1, dash='dash'),
                       name='BB Upper', showlegend=False),
            row=1, col=1
        )
        fig.add_trace(
            go.Scatter(x=dates, y=self.df['bb_lower'],
                       line=dict(color='gray', width=1, dash='dash'),
                       fill='tonexty', fillcolor='rgba(128,128,128,0.1)',
                       name='Bollinger Bands'),
            row=1, col=1
        )
        fig.update_layout(
            height=600,
            title_text=f"Stock Price with Technical Analysis for {ticker}",
            showlegend=True,
            xaxis_rangeslider_visible=True
        )
        fig.update_yaxes(title_text="Price ($)", row=1, col=1)
        fig.update_xaxes(title_text="Date", row=1, col=1)
        return fig

    def create_default_chart(self):
        """Return a placeholder figure with made-up sample points, shown before/without real data."""
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=[0, 1, 2, 3, 4],
            y=[100, 110, 105, 115, 120],
            mode='lines+markers',
            name='Sample Stock Data',
            line=dict(color='#2E86C1', width=2)
        ))
        fig.update_layout(
            title="📈 Example Stock Chart Preview (Note: This is not the actual graph.)",
            xaxis_title="Time Period",
            yaxis_title="Price ($)",
            height=400,
            template="plotly_white"
        )
        return fig

    def get_statistics(self, ticker: str) -> str:
        """Return a Markdown summary (latest close, date range, row count) of the loaded data."""
        if self.df is None or self.df.empty:
            return "No data available"
        try:
            current_price = self.df['close'].iloc[-1]
            first_day = self.df.index[0].strftime('%Y-%m-%d')
            last_day = self.df.index[-1].strftime('%Y-%m-%d')
            lines = [
                f"📊 **Key Statistics for {ticker}**",
                "**Price Performance:**",
                f"• Current Price: ${current_price:.2f}",
                "**Trading Activity:**",
                f"• Data Period: {first_day} to {last_day}",
                f"• Total Trading Days: {len(self.df)}",
            ]
            return "\n".join(lines)
        except Exception as e:
            return f"Error calculating statistics: {e}"

    def get_default_statistics(self):
        """Return the welcome Markdown shown in the statistics panel when no data is loaded."""
        return (
            "\n"
            "📊 **Stock Statistics Preview**\n"
            "Welcome to AI Stock Analysis!\n"
            "**What you'll get:**\n"
            "• Real-time stock price data\n"
            "• Technical analysis indicators\n"
            "• AI-powered investment recommendations\n"
            "• Risk assessment and insights\n"
            "📈 **Enter a stock ticker below to begin analysis**\n"
        )
class SecureAPIHandler:
    """Validate and export the API keys the AI crew needs.

    The Anthropic key is supplied by the user; the Finnhub key is read from the
    ``FINNHUB_API_KEY`` environment variable (a HuggingFace Space secret).

    NOTE(review): keys are exported through ``os.environ``, which is
    process-wide, so concurrent users of one server share whichever Anthropic
    key was set most recently.
    """

    def __init__(self):
        # Masked previews ("***" + last four characters) of configured keys, never raw values.
        self.session_keys = {}
        # Snapshot taken at construction; check_finnhub_secret() re-reads the environment each call.
        self.finnhub_available = bool(os.getenv("FINNHUB_API_KEY"))

    @staticmethod
    def _mask(secret: str) -> str:
        """Return a display-safe preview of *secret*: '***' followed by its last four characters."""
        return "***" + secret[-4:]

    def validate_anthropic_key(self, key_value: str) -> tuple[bool, str]:
        """Check the Anthropic key's format only (non-empty, 'sk-ant-' prefix, >= 20 chars); no network call.

        Returns:
            ``(is_valid, message)``; *message* is "Valid" or the reason for rejection.
        """
        if not key_value or not key_value.strip():
            return False, "Anthropic API key is required"
        key_value = key_value.strip()
        if not key_value.startswith("sk-ant-"):
            return False, "Anthropic API key should start with 'sk-ant-'"
        if len(key_value) < 20:
            return False, "Anthropic API key appears too short"
        return True, "Valid"

    def check_finnhub_secret(self) -> tuple[bool, str]:
        """Check that ``FINNHUB_API_KEY`` is set and at least 10 characters after stripping.

        Returns:
            ``(is_available, message)``.
        """
        finnhub_key = os.getenv("FINNHUB_API_KEY")
        if not finnhub_key:
            return False, "Finnhub API key not found in HuggingFace secrets. Please add FINNHUB_API_KEY to your space secrets."
        if len(finnhub_key.strip()) < 10:
            return False, "Finnhub API key in secrets appears invalid"
        return True, "Finnhub API key loaded from HuggingFace secrets"

    def set_anthropic_key(self, anthropic_key: str) -> tuple[bool, str]:
        """Validate both keys, export the stripped Anthropic key and record masked previews.

        Returns:
            ``(success, message)``. Never raises; unexpected errors are reported
            in *message*.
        """
        try:
            is_valid, message = self.validate_anthropic_key(anthropic_key)
            if not is_valid:
                return False, f"Validation Error: {message}"
            finnhub_valid, finnhub_message = self.check_finnhub_secret()
            if not finnhub_valid:
                return False, f"Secret Configuration Error: {finnhub_message}"
            # Mask the same stripped values that are validated/exported, so stray
            # whitespace or newlines never end up in the preview.
            clean_key = anthropic_key.strip()
            finnhub_key = os.getenv("FINNHUB_API_KEY", "").strip()
            os.environ["ANTHROPIC_API_KEY"] = clean_key
            self.session_keys["ANTHROPIC_API_KEY"] = self._mask(clean_key)
            self.session_keys["FINNHUB_API_KEY"] = self._mask(finnhub_key)
            return True, "API keys validated and configured successfully"
        except Exception as e:
            return False, f"Error configuring API keys: {str(e)}"

    def clear_session_keys(self):
        """Remove the exported Anthropic key from the environment and forget all masked previews."""
        os.environ.pop("ANTHROPIC_API_KEY", None)
        self.session_keys.clear()
# Global instances
# Module-level singletons shared by the Gradio callbacks defined below.
stock_processor = StockDataProcessor()
api_handler = SecureAPIHandler()
# Paths of temporary JSON files waiting to be removed by cleanup_temp_files().
temp_files_to_cleanup: list[str] = []
def cleanup_temp_files():
    """Delete every path queued in ``temp_files_to_cleanup`` and then empty the queue.

    Best effort: files that are already gone are skipped, and any other
    failure is reported on stderr instead of being raised. Unlike a bare
    ``except``, this no longer swallows KeyboardInterrupt/SystemExit.
    """
    # The list is only mutated in place (never rebound), so no `global` statement is needed.
    for temp_file in temp_files_to_cleanup:
        try:
            # EAFP: avoids the exists()/unlink() race of checking first.
            os.unlink(temp_file)
        except FileNotFoundError:
            pass  # already removed, nothing to do
        except Exception as exc:  # e.g. PermissionError / IsADirectoryError: log, never crash
            print(f"[cleanup] could not remove {temp_file!r}: {exc}", file=sys.stderr)
    temp_files_to_cleanup.clear()
def load_stock_data_and_chart(ticker_symbol: str = "AAPL", progress=gr.Progress()):
    """Fetch recent history for one ticker and build the status, statistics and chart outputs.

    Args:
        ticker_symbol: Ticker entered or selected in the UI. Surrounding
            whitespace is ignored and it is upper-cased; an empty/None value is
            rejected before any network call.
        progress: Gradio progress tracker (Gradio supplies it when this runs as
            an event handler).

    Returns:
        A ``(status_markdown, stats_markdown, plotly_figure)`` tuple. On any
        failure the status holds an error message and the placeholder
        statistics and chart are returned instead.
    """
    def _failure(message):
        # Shared fallback for every error path: error status plus placeholder stats/chart.
        return message, stock_processor.get_default_statistics(), stock_processor.create_default_chart()

    ticker = (ticker_symbol or "").strip().upper()
    if not ticker:
        # An empty symbol would otherwise reach yfinance and fail with "Could not fetch data for ."
        return _failure("❌ Error: Please enter a stock ticker symbol.")
    try:
        if progress:
            progress(0.1, desc="Fetching stock data...")
        if not stock_processor.get_stock_data(ticker):
            return _failure(f"❌ Error: Could not fetch data for {ticker}. Please check the ticker symbol.")
        if progress:
            progress(0.7, desc="Processing data and creating visualization...")
        chart_fig = stock_processor.create_interactive_chart(ticker)
        stats = stock_processor.get_statistics(ticker)
        if progress:
            progress(1.0, desc="Complete!")
        return f"✅ Successfully loaded data for {ticker}", stats, chart_fig
    except Exception as e:
        return _failure(f"❌ Error loading stock data: {str(e)}")
def load_company_data(ticker: str):
    """Load *ticker* in response to a quick-access or custom-ticker button click.

    Returns:
        ``(ticker, status, stats, chart)``. The first element fills the
        read-only "Company Ticker Symbol" box, so it is normalized (stripped,
        upper-cased) to match the symbol that was actually fetched.
    """
    status, stats, chart = load_stock_data_and_chart(ticker)
    return (ticker or "").strip().upper(), status, stats, chart
def run_ai_analysis(company_ticker: str, max_amount: str, anthropic_key: str, progress=gr.Progress()):
    """Validate the form, hand fresh price data to the CrewAI agents and return their report.

    Steps: validate inputs, export API keys via ``api_handler``, fetch data
    into ``stock_processor``, write it to a temporary JSON file whose path is
    published as ``TEMP_JSON_FILE``, and run ``PredictingStock().crew()``.
    Queued temp files are always cleaned up when this returns.

    Args:
        company_ticker: Ticker to analyse (stripped and upper-cased).
        max_amount: Maximum investment as text; must parse to a finite number > 0.
        anthropic_key: User-supplied Anthropic API key.
        progress: Gradio progress tracker.

    Returns:
        The formatted report text, or an error message starting with "❌".
        Never raises.
    """
    try:
        progress(0.05, desc="Validating inputs...")
        if not company_ticker or not company_ticker.strip():
            return "❌ Error: Please enter a company ticker symbol"
        if not max_amount or not max_amount.strip():
            return "❌ Error: Please enter a maximum investment amount"
        amount_text = max_amount.strip()
        try:
            amount = float(amount_text)
        except ValueError:
            return "❌ Error: Maximum investment amount must be a valid number"
        # float() also accepts "nan", "inf" and negatives, none of which is a usable budget.
        # NaN fails every comparison, so this chained check rejects it as well.
        if not 0 < amount < float("inf"):
            return "❌ Error: Maximum investment amount must be a positive number"
        if not anthropic_key or not anthropic_key.strip():
            return "❌ Error: Anthropic API key is required"
        ticker = company_ticker.strip().upper()

        progress(0.2, desc="Configuring API keys...")
        # NOTE(review): keys go into the process-wide os.environ, shared by concurrent sessions.
        success, message = api_handler.set_anthropic_key(anthropic_key)
        if not success:
            return f"❌ {message}"

        progress(0.3, desc="Fetching stock data for AI analysis...")
        if not stock_processor.get_stock_data(ticker):
            return f"❌ Error: Could not fetch stock data for {ticker}"
        temp_json_file = stock_processor.create_temp_json_file(ticker)
        if not temp_json_file:
            return "❌ Error: Could not create temporary data file for AI analysis"
        temp_files_to_cleanup.append(temp_json_file)
        # The crew reads the data file path from this environment variable.
        os.environ["TEMP_JSON_FILE"] = temp_json_file

        progress(0.4, desc="Initializing AI agents...")
        inputs = {
            "company_name": ticker,
            "max_amount": amount_text,
        }
        progress(0.6, desc="AI agents are running now... This may take 2-5 minutes")
        result = PredictingStock().crew().kickoff(inputs=inputs)
        progress(1.0, desc="Analysis complete!")

        return "\n".join([
            f"🎯 COMPREHENSIVE STOCK ANALYSIS FOR: {ticker}",
            "=" * 60,
            f"{result.raw}",
        ]).strip()
    except Exception as e:
        return "\n".join([
            "❌ ERROR DURING AI ANALYSIS:",
            str(e),
            "TROUBLESHOOTING TIPS:",
            "1. Verify your Anthropic API key is correct",
            "2. Check that FINNHUB_API_KEY is set in your HuggingFace space secrets",
            "3. Check your internet connection",
            "4. Ensure the ticker symbol is valid",
            "5. Try again in a few moments",
            "Full error details:",
            traceback.format_exc(),
        ]).strip()
    finally:
        # Runs on every exit path (success, early return, error).
        cleanup_temp_files()
def create_secure_interface():
    """Build and return the Gradio Blocks app for stock visualization and AI analysis.

    Layout, top to bottom (component creation order inside ``gr.Blocks`` is
    what determines it): header, quick-access company buttons, a custom-ticker
    loader, a statistics + chart panel, the analysis form (ticker / maximum
    amount / Anthropic key), action buttons, the AI report box, and a footer.

    Building the app calls ``load_stock_data_and_chart("AAPL")`` once so the
    statistics panel and chart start populated; that call goes through
    ``stock_processor`` and therefore fetches data while the UI is constructed.

    Returns:
        The ``gr.Blocks`` instance (not launched).
    """
    # Custom CSS; the classes are used via elem_classes or in the gr.HTML snippets below.
    css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.header {
text-align: center;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
margin-bottom: 30px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.nav-bar {
background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
padding: 15px;
border-radius: 10px;
margin-bottom: 20px;
text-align: center;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
.nav-button {
margin: 5px !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
border-radius: 8px !important;
padding: 10px 20px !important;
transition: all 0.3s ease !important;
}
.nav-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2) !important;
}
.stats-box {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
color: white;
padding: 15px;
border-radius: 10px;
margin: 10px 0;
}
.secure-input {
border: 2px solid #28a745 !important;
}
.section-header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 10px;
border-radius: 8px;
margin: 15px 0 10px 0;
text-align: center;
}
"""
    with gr.Blocks(css=css, title="AI Stock Prediction with Visualization") as interface:
        # Header
        gr.HTML("""
<div class="header">
<h1>🚀 AI Stock Prediction & Visualization System</h1>
<p>Complete stock analysis with interactive charts and AI-powered insights</p>
</div>
""")
        # NAVIGATION BAR: one button per fixed ticker; their handlers are wired up at the end.
        gr.HTML('<div class="nav-bar"><h3>🏢 Quick Access - Popular Companies</h3></div>')
        with gr.Row():
            apple_btn = gr.Button("🍎 Apple (AAPL)", elem_classes=["nav-button"], scale=1)
            google_btn = gr.Button("🔍 Google (GOOGL)", elem_classes=["nav-button"], scale=1)
            microsoft_btn = gr.Button("💻 Microsoft (MSFT)", elem_classes=["nav-button"], scale=1)
            tesla_btn = gr.Button("⚡ Tesla (TSLA)", elem_classes=["nav-button"], scale=1)
            amazon_btn = gr.Button("📦 Amazon (AMZN)", elem_classes=["nav-button"], scale=1)
        # CUSTOM COMPANY SECTION: free-text ticker plus a button that loads it.
        gr.HTML('<div class="nav-bar"><h4>✨ Add Any Company</h4></div>')
        with gr.Row():
            with gr.Column(scale=3):
                custom_ticker_input = gr.Textbox(
                    label="Enter Any Stock Ticker",
                    placeholder="e.g., NVDA, META, NFLX, etc.",
                    value="",
                    info="Type any valid stock ticker symbol",
                    elem_classes=["secure-input"]
                )
            with gr.Column(scale=1):
                custom_load_btn = gr.Button(
                    "🚀 Load Custom Stock",
                    elem_classes=["nav-button"],
                    variant="primary",
                    size="lg"
                )
        # TOP SECTION: Statistics and Chart
        gr.HTML('<div class="section-header"><h2>📊 Live Statistics & Interactive Chart</h2></div>')
        # Load default stock data (AAPL) on startup; this runs once, while the UI is being built.
        default_message, default_stats, default_chart = load_stock_data_and_chart("AAPL")
        with gr.Row():
            with gr.Column(scale=1):
                stats_output = gr.Markdown(
                    value=default_stats,
                    elem_classes=["stats-box"]
                )
            with gr.Column(scale=2):
                chart_output = gr.Plot(
                    value=default_chart,
                    label="Stock Price Chart",
                    show_label=True
                )
        # Status message for stock data loading
        stock_status = gr.Markdown(value=default_message)
        # MIDDLE SECTION: Input Configuration
        gr.HTML('<div class="section-header"><h2>⚙️ Analysis Configuration</h2></div>')
        with gr.Row():
            with gr.Column(scale=1):
                # Read-only: only the navigation / custom-load handlers below write to it.
                company_ticker = gr.Textbox(
                    label="Company Ticker Symbol",
                    placeholder="Selected company ticker will appear here",
                    value="AAPL", # Start with AAPL
                    info="Stock ticker symbol to analyze",
                    elem_classes=["secure-input"],
                    interactive=False
                )
            with gr.Column(scale=1):
                max_amount = gr.Textbox(
                    label="Maximum Investment Amount (USD)",
                    placeholder="e.g., 1000, 5000.50",
                    value="",
                    info="The maximum amount you are willing to invest",
                    elem_classes=["secure-input"]
                )
            with gr.Column(scale=1):
                # type="password" masks the key in the browser.
                anthropic_key = gr.Textbox(
                    label="Anthropic API Key (Claude)",
                    placeholder="sk-ant-api03-...",
                    type="password",
                    info="Get from console.anthropic.com - Required for AI analysis",
                    elem_classes=["secure-input"]
                )
        with gr.Row():
            with gr.Column():
                load_data_btn = gr.Button(
                    "📊 Load Stock Data",
                    variant="secondary",
                    size="lg",
                    scale=1
                )
            with gr.Column():
                analyze_btn = gr.Button(
                    "🤖 Run AI Analysis",
                    variant="primary",
                    size="lg",
                    scale=2
                )
        # BOTTOM SECTION: AI Analysis Results
        gr.HTML('<div class="section-header"><h2>🤖 AI Agent Analysis & Recommendations</h2></div>')
        result_output = gr.TextArea(
            label="AI Analysis Report",
            lines=25,
            max_lines=35,
            interactive=False,
            show_copy_button=True,
            placeholder="🤖 AI analysis results and investment recommendations will appear here after running the analysis...\n\nThe AI agents will provide:\n• Comprehensive stock analysis\n• Risk assessment\n• Investment recommendations\n• Market insights\n• Technical analysis summary"
        )
        # Event handlers.
        # Quick-access buttons: each loads a fixed ticker and also overwrites the read-only
        # company_ticker box, so "Run AI Analysis" targets the company that was just shown.
        apple_btn.click(
            fn=lambda: load_company_data("AAPL"),
            outputs=[company_ticker, stock_status, stats_output, chart_output]
        )
        google_btn.click(
            fn=lambda: load_company_data("GOOGL"),
            outputs=[company_ticker, stock_status, stats_output, chart_output]
        )
        microsoft_btn.click(
            fn=lambda: load_company_data("MSFT"),
            outputs=[company_ticker, stock_status, stats_output, chart_output]
        )
        tesla_btn.click(
            fn=lambda: load_company_data("TSLA"),
            outputs=[company_ticker, stock_status, stats_output, chart_output]
        )
        amazon_btn.click(
            fn=lambda: load_company_data("AMZN"),
            outputs=[company_ticker, stock_status, stats_output, chart_output]
        )
        # Custom ticker: same outputs, but the ticker comes from the free-text box.
        custom_load_btn.click(
            fn=lambda ticker: load_company_data(ticker),
            inputs=[custom_ticker_input],
            outputs=[company_ticker, stock_status, stats_output, chart_output]
        )
        # Reload the chart/stats for whatever ticker is currently selected.
        load_data_btn.click(
            fn=load_stock_data_and_chart,
            inputs=[company_ticker],
            outputs=[stock_status, stats_output, chart_output]
        )
        # Run the CrewAI analysis; the report (or an error message) goes to result_output.
        analyze_btn.click(
            fn=run_ai_analysis,
            inputs=[company_ticker, max_amount, anthropic_key],
            outputs=[result_output]
        )
        # Footer
        gr.HTML("""
<div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #dee2e6;">
<p>🔐 <strong>Secure AI Stock Analysis with Interactive Visualization</strong></p>
<p style="font-size: 12px; color: #6c757d;">
Features: Real-time data • Technical indicators • AI analysis • Interactive charts
</p>
</div>
""")
    return interface
# Launch the application
if __name__ == "__main__":
    # Runs only when executed as a script; importing the module never launches the UI.
    interface = create_secure_interface()
    interface.launch()