Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import yfinance as yf | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from datetime import datetime, timedelta | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| import chromadb | |
| import requests | |
| from bs4 import BeautifulSoup | |
| import warnings | |
| from typing import Dict, List, Tuple | |
| import feedparser | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import json | |
| import os | |
# Suppress every warning, from any library, for the rest of the process.
warnings.filterwarnings('ignore')

# Dropdown label -> ticker symbol.  Each label embeds its ticker in
# parentheses; the news scrapers and the chatbot parse it back out of the
# label, so keep the "Name (TICKER)" shape when adding entries.
COMPANIES = dict((
    ('Apple (AAPL)', 'AAPL'),
    ('Microsoft (MSFT)', 'MSFT'),
    ('Amazon (AMZN)', 'AMZN'),
    ('Google (GOOGL)', 'GOOGL'),
    ('Meta (META)', 'META'),
    ('Tesla (TSLA)', 'TSLA'),
    ('NVIDIA (NVDA)', 'NVDA'),
    ('JPMorgan Chase (JPM)', 'JPM'),
    ('Johnson & Johnson (JNJ)', 'JNJ'),
    ('Walmart (WMT)', 'WMT'),
    ('Visa (V)', 'V'),
    ('Mastercard (MA)', 'MA'),
    ('Procter & Gamble (PG)', 'PG'),
    ('UnitedHealth (UNH)', 'UNH'),
    ('Home Depot (HD)', 'HD'),
    ('Bank of America (BAC)', 'BAC'),
    ('Coca-Cola (KO)', 'KO'),
    ('Pfizer (PFE)', 'PFE'),
    ('Disney (DIS)', 'DIS'),
    ('Netflix (NFLX)', 'NFLX'),
))
# Initialize models
print("Initializing models...")

# Hugging Face API token.  The app's own secret is named TOKEN; the
# conventional HF_TOKEN variable is accepted as a fallback so the app also
# works when the token is configured the standard way.
api_token = os.getenv("TOKEN") or os.getenv("HF_TOKEN")
if not api_token:
    # NOTE(review): without a token the endpoint calls may be rejected or
    # rate-limited; report it at startup instead of on the first request.
    print("Warning: no Hugging Face token found in TOKEN or HF_TOKEN.")

# Hosted Mistral-7B-Instruct endpoint, shared by LLMAgent and ChatbotRouter.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    huggingfacehub_api_token=api_token,
    temperature=0.7,
    max_new_tokens=1000,
)

# Two headline scorers used by SentimentAnalysisAgent: the lexicon-based
# VADER analyzer and the FinBERT classification pipeline.
vader = SentimentIntensityAnalyzer()
finbert = pipeline("sentiment-analysis",
                   model="ProsusAI/finbert")
print("Models initialized successfully!")
class AgenticRAGFramework:
    """Coordinate the technical, sentiment and LLM agents for one analysis.

    Attributes:
        technical_agent: computes indicators and signals from price data.
        sentiment_agent: gathers headlines and scores their sentiment.
        llama_agent: turns both results into a narrative via the LLM.
        knowledge_base: a chromadb client; created here but not read by
            :meth:`analyze`.
    """

    def __init__(self):
        self.technical_agent = TechnicalAnalysisAgent()
        self.sentiment_agent = SentimentAnalysisAgent()
        self.llama_agent = LLMAgent()
        self.knowledge_base = chromadb.Client()

    def analyze(self, symbol: str, data: pd.DataFrame) -> Dict:
        """Run every agent and bundle their outputs.

        Args:
            symbol: ticker handed to the sentiment agent's news search.
            data: price history handed to the technical agent.

        Returns:
            Dict with ``technical_analysis``, ``sentiment_analysis`` and
            ``llm_analysis`` entries, in that order.
        """
        results = {
            'technical_analysis': self.technical_agent.analyze(data),
            'sentiment_analysis': self.sentiment_agent.analyze(symbol),
        }
        # The LLM summarises the two agent results computed above.
        results['llm_analysis'] = self.llama_agent.generate_analysis(
            results['technical_analysis'],
            results['sentiment_analysis'],
        )
        return results
class NewsSource:
    """Base class for news sources.

    Subclasses implement :meth:`get_news` and return a list of dicts with the
    keys ``title``, ``description``, ``date``, ``source`` and ``url``.
    """

    # Browser-style User-Agent sent with every scraping request.
    HEADERS = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    }
    # Seconds to wait for a server.  Without a timeout a stalled connection
    # would block the calling analysis indefinitely.
    TIMEOUT = 10
    # Maximum number of articles each source returns.
    MAX_ITEMS = 5

    @staticmethod
    def _extract_ticker(company: str) -> str:
        """Return the ticker from a label such as ``'Apple (AAPL)'``.

        A bare ticker such as ``'AAPL'`` is returned unchanged.
        """
        return company.split('(')[-1].replace(')', '').strip()

    def get_news(self, company: str) -> List[Dict]:
        """Return recent articles for *company* (label or bare ticker)."""
        raise NotImplementedError


class FinvizNews(NewsSource):
    """Fetch headlines from the FinViz quote page."""

    def get_news(self, company: str) -> List[Dict]:
        try:
            ticker = self._extract_ticker(company)
            url = f"https://finviz.com/quote.ashx?t={ticker}"
            response = requests.get(url, headers=self.HEADERS, timeout=self.TIMEOUT)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')
            # NOTE(review): FinViz appears to mark the headline table with
            # id="news-table" rather than a class; try the id first and keep
            # the class lookup as a fallback.
            news_table = (soup.find('table', id='news-table')
                          or soup.find('table', {'class': 'news-table'}))
            if not news_table:
                return []
            news_list = []
            for row in news_table.find_all('tr'):
                cols = row.find_all('td')
                # Skip spacer/malformed rows instead of letting one row
                # without a link abort the whole scrape.
                if len(cols) < 2 or cols[1].a is None:
                    continue
                anchor = cols[1].a
                title = anchor.text.strip()
                news_list.append({
                    'title': title,
                    'description': title,
                    'date': cols[0].text.strip(),
                    'source': 'FinViz',
                    'url': anchor.get('href', ''),
                })
                if len(news_list) >= self.MAX_ITEMS:
                    break
            return news_list
        except Exception as e:
            # Best effort: a failing source must not break the analysis.
            print(f"FinViz Error: {str(e)}")
            return []


class MarketWatchNews(NewsSource):
    """Fetch headlines from the MarketWatch quote page."""

    def get_news(self, company: str) -> List[Dict]:
        try:
            ticker = self._extract_ticker(company)
            url = f"https://www.marketwatch.com/investing/stock/{ticker}"
            response = requests.get(url, headers=self.HEADERS, timeout=self.TIMEOUT)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')
            news_list = []
            for element in soup.find_all('div', {'class': 'article__content'}):
                title_elem = element.find('a', {'class': 'link'})
                if not title_elem:
                    continue
                title = title_elem.text.strip()
                date_elem = element.find('span', {'class': 'article__timestamp'})
                news_list.append({
                    'title': title,
                    'description': title,
                    'date': date_elem.text.strip() if date_elem else 'Recent',
                    'source': 'MarketWatch',
                    'url': title_elem.get('href', ''),
                })
                if len(news_list) >= self.MAX_ITEMS:
                    break
            return news_list
        except Exception as e:
            # Best effort: a failing source must not break the analysis.
            print(f"MarketWatch Error: {str(e)}")
            return []


class YahooRSSNews(NewsSource):
    """Fetch headlines from the Yahoo Finance RSS feed."""

    @staticmethod
    def _feed_url(ticker: str) -> str:
        """Build the headline-feed URL for *ticker*.

        The query string is produced by ``urlencode`` so the ticker is
        escaped and each parameter gets its own separator.  The previous
        hand-written query had its region separator mangled into a
        registered-trademark sign, so Yahoo received a bogus ticker.
        """
        from urllib.parse import urlencode
        query = urlencode({'s': ticker, 'region': 'US', 'lang': 'en-US'})
        return f"https://feeds.finance.yahoo.com/rss/2.0/headline?{query}"

    def get_news(self, company: str) -> List[Dict]:
        try:
            ticker = self._extract_ticker(company)
            feed = feedparser.parse(self._feed_url(ticker))
            news_list = []
            for entry in feed.entries[:self.MAX_ITEMS]:
                title = entry.get('title', '')
                if not title:
                    continue
                news_list.append({
                    'title': title,
                    # Not every item carries a description or date; fall
                    # back instead of raising and losing the whole feed.
                    'description': entry.get('description', title),
                    'date': entry.get('published', 'Recent'),
                    'source': 'Yahoo Finance',
                    'url': entry.get('link', ''),
                })
            return news_list
        except Exception as e:
            # Best effort: a failing source must not break the analysis.
            print(f"Yahoo RSS Error: {str(e)}")
            return []
class TechnicalAnalysisAgent:
    """Agent for technical analysis.

    Adds returns, SMAs, RSI, MACD and Bollinger Band columns to a price
    DataFrame and derives textual signals from its most recent row.  Accepts
    flat columns (``'Close'``) as well as two-level columns such as
    ``('Close', 'AAPL')``; the column layout of the input is preserved.
    """

    def __init__(self):
        # Indicator windows, measured in rows.
        self.required_periods = {
            'sma': [20, 50, 200],
            'rsi': 14,
            'volatility': 20,  # not read by this class
            'macd': [12, 26, 9],  # fast EMA span, slow EMA span, signal span
            'bollinger': 20,
        }

    @staticmethod
    def _column(df: pd.DataFrame, name: str):
        """Return the key of price column *name* for flat or two-level columns."""
        if isinstance(df.columns, pd.MultiIndex):
            return (name, df.columns.get_level_values(1)[0])
        return name

    @staticmethod
    def _value(row: pd.Series, key) -> float:
        """Return ``row[key]`` as a float.

        With two-level columns an indicator lookup such as ``row['SMA_20']``
        may yield a one-element Series; unwrap it explicitly instead of
        relying on ``float(Series)``.
        """
        value = row[key]
        if isinstance(value, pd.Series):
            value = value.iloc[0] if len(value) else np.nan
        return float(value)

    def analyze(self, data: pd.DataFrame) -> Dict:
        """Compute indicators on a copy of *data* and summarise the last row.

        Args:
            data: price history containing at least a ``Close`` column.

        Returns:
            ``{'processed_data': DataFrame, 'current_signals': Dict}``.
        """
        df = data.copy()
        close_col = self._column(df, 'Close')
        close = df[close_col]
        periods = self.required_periods

        df['Returns'] = close.pct_change()

        for period in periods['sma']:
            df[f'SMA_{period}'] = close.rolling(window=period).mean()

        # RSI from simple rolling means of gains and losses.
        rsi_window = periods['rsi']
        delta = close.diff()
        gain = delta.where(delta > 0, 0).rolling(window=rsi_window).mean()
        loss = -delta.where(delta < 0, 0).rolling(window=rsi_window).mean()
        rs = gain / loss
        df['RSI'] = 100 - (100 / (1 + rs))

        # MACD line and its signal line.
        fast_span, slow_span, signal_span = periods['macd']
        macd = (close.ewm(span=fast_span, adjust=False).mean()
                - close.ewm(span=slow_span, adjust=False).mean())
        df['MACD'] = macd
        df['Signal_Line'] = macd.ewm(span=signal_span, adjust=False).mean()

        # Bollinger Bands: rolling mean +/- two rolling standard deviations.
        bb_window = periods.get('bollinger', 20)
        middle = close.rolling(window=bb_window).mean()
        width = 2 * close.rolling(window=bb_window).std()
        df['BB_middle'] = middle
        df['BB_upper'] = middle + width
        df['BB_lower'] = middle - width

        return {
            'processed_data': df,
            'current_signals': self._generate_signals(df, close_col),
        }

    def _generate_signals(self, df: pd.DataFrame, close_col) -> Dict:
        """Turn the last row of *df* into textual signals.

        Args:
            df: frame produced by :meth:`analyze`.
            close_col: key of the close-price column in *df*.

        Returns:
            Dict with ``trend``, ``rsi_signal``, ``macd_signal`` and
            ``bb_position``.  A signal whose inputs are NaN on the last row
            (e.g. ``SMA_50`` with fewer than 50 rows) is ``'Unknown'``
            rather than a misleading default such as ``'Bearish'``.
        """
        signals = dict.fromkeys(
            ('trend', 'rsi_signal', 'macd_signal', 'bb_position'), 'Unknown')
        if df.empty:
            return signals

        current = df.iloc[-1]
        sma_20 = self._value(current, 'SMA_20')
        sma_50 = self._value(current, 'SMA_50')
        rsi = self._value(current, 'RSI')
        macd = self._value(current, 'MACD')
        signal_line = self._value(current, 'Signal_Line')
        close = self._value(current, close_col)
        bb_upper = self._value(current, 'BB_upper')
        bb_lower = self._value(current, 'BB_lower')

        if not (np.isnan(sma_20) or np.isnan(sma_50)):
            signals['trend'] = 'Bullish' if sma_20 > sma_50 else 'Bearish'

        if not np.isnan(rsi):
            if rsi > 70:
                signals['rsi_signal'] = 'Overbought'
            elif rsi < 30:
                signals['rsi_signal'] = 'Oversold'
            else:
                signals['rsi_signal'] = 'Neutral'

        if not (np.isnan(macd) or np.isnan(signal_line)):
            signals['macd_signal'] = 'Buy' if macd > signal_line else 'Sell'

        if not any(np.isnan(v) for v in (close, bb_upper, bb_lower)):
            if close > bb_upper:
                signals['bb_position'] = 'Above Upper Band'
            elif close < bb_lower:
                signals['bb_position'] = 'Below Lower Band'
            else:
                signals['bb_position'] = 'Within Bands'

        return signals
class SentimentAnalysisAgent:
    """Agent for headline sentiment analysis.

    Collects articles from several news sources, scores every headline with
    both VADER and FinBERT, and combines the two averages into one reading.
    """

    # FinBERT label -> polarity.  ProsusAI/finbert emits positive / negative /
    # neutral labels (per its model card).  'neutral' must map to 0: counting
    # it as negative biases every aggregate towards Bearish.  Unknown labels
    # are also treated as 0.
    FINBERT_POLARITY = {'positive': 1.0, 'negative': -1.0, 'neutral': 0.0}

    def __init__(self):
        self.news_sources = [
            FinvizNews(),
            MarketWatchNews(),
            YahooRSSNews()
        ]

    def analyze(self, symbol: str) -> Dict:
        """Fetch headlines for *symbol* and score them.

        Returns:
            Dict with the raw ``articles``, per-article ``vader_scores`` and
            ``finbert_scores``, and the ``aggregated`` summary.
        """
        all_news = []
        for source in self.news_sources:
            all_news.extend(source.get_news(symbol))

        vader_scores = []
        finbert_scores = []
        for article in all_news:
            title = article['title']
            vader_scores.append(vader.polarity_scores(title))
            finbert_scores.append(finbert(title[:512])[0])

        return {
            'articles': all_news,
            'vader_scores': vader_scores,
            'finbert_scores': finbert_scores,
            'aggregated': self._aggregate_sentiment(vader_scores, finbert_scores)
        }

    def _aggregate_sentiment(self, vader_scores: List[Dict],
                             finbert_scores: List[Dict]) -> Dict:
        """Average both scorers into one sentiment label and confidence.

        Args:
            vader_scores: VADER results, each with a ``compound`` score.
            finbert_scores: FinBERT results, each with a ``label``.

        Returns:
            Dict with ``sentiment`` (Bullish / Bearish / Neutral, using a
            +/-0.1 dead band on the combined score), ``confidence`` (absolute
            combined score) and the two per-model averages.
        """
        if not vader_scores or not finbert_scores:
            return {
                'sentiment': 'Neutral',
                'confidence': 0,
                'vader_sentiment': 0,
                'finbert_sentiment': 0
            }
        avg_vader = np.mean([score['compound'] for score in vader_scores])
        avg_finbert = np.mean([
            self.FINBERT_POLARITY.get(str(score['label']).lower(), 0.0)
            for score in finbert_scores
        ])
        combined_score = (avg_vader + avg_finbert) / 2
        if combined_score > 0.1:
            sentiment = 'Bullish'
        elif combined_score < -0.1:
            sentiment = 'Bearish'
        else:
            sentiment = 'Neutral'
        return {
            'sentiment': sentiment,
            'confidence': abs(combined_score),
            'vader_sentiment': avg_vader,
            'finbert_sentiment': avg_finbert
        }
class LLMAgent:
    """Agent that asks the hosted LLM (HuggingFace endpoint) for a narrative analysis."""

    def __init__(self):
        # Shared module-level HuggingFaceEndpoint instance.
        self.llm = llm

    def generate_analysis(self, technical_data: Dict, sentiment_data: Dict) -> str:
        """Render the prompt from both agents' results and return the LLM reply."""
        return self.llm.invoke(self._create_prompt(technical_data, sentiment_data))

    def _create_prompt(self, technical_data: Dict, sentiment_data: Dict) -> str:
        """Build the prompt: current signals followed by the requested sections."""
        signals = technical_data['current_signals']
        overall = sentiment_data['aggregated']
        lines = [
            "Based on technical and sentiment indicators:",
            "Technical Signals:",
            f"- Trend: {signals['trend']}",
            f"- RSI: {signals['rsi_signal']}",
            f"- MACD: {signals['macd_signal']}",
            f"- BB Position: {signals['bb_position']}",
            f"- Sentiment: {overall['sentiment']} (Confidence: {overall['confidence']:.2f})",
            "Provide:",
            "1. Current Trend Analysis",
            "2. Key Risk Factors",
            "3. Trading Recommendations",
            "4. Price Targets",
            "5. Near-term Outlook (1-2 weeks)",
            "Note: return only required information and nothing unnecessary",
        ]
        return "\n".join(lines)
| # class ChatbotRouter: | |
| # """Routes chatbot queries to appropriate data sources and generates responses""" | |
| # def __init__(self): | |
| # self.llm = llm | |
| # self.encoder = SentenceTransformer('all-MiniLM-L6-v2') | |
| # self.faiss_index = None | |
| # self.company_data = {} | |
| # self.news_sources = [ | |
| # FinvizNews(), | |
| # MarketWatchNews(), | |
| # YahooRSSNews() | |
| # ] | |
| # self.load_faiss_index() | |
| # def route_and_respond(self, query: str, company: str) -> str: | |
| # query_type = self._classify_query(query.lower()) | |
| # route_message = f"\n[Taking {query_type.upper()} route]\n\n" | |
| # if query_type == "company_info": | |
| # context = self._get_company_context(query, company) | |
| # elif query_type == "news": | |
| # context = self._get_news_context(company) | |
| # elif query_type == "price": | |
| # context = self._get_price_context(company) | |
| # else: | |
| # return route_message + "I'm not sure how to handle this query. Please ask about company information, news, or price data." | |
| # prompt = self._create_prompt(query, context, query_type) | |
| # response = self.llm.invoke(prompt) | |
| # return route_message + response | |
class ChatbotRouter:
    """Routes chatbot queries to appropriate data sources and generates responses.

    A query is classified by keyword into ``news``, ``price`` or
    ``company_info``.  The matching context is fetched and sent, together
    with the question, to the shared ``llm`` endpoint.
    """

    # Keyword lists per route, checked in this order.  The specific routes
    # come first: the company-info list holds generic words ("what", "about",
    # ...) that occur in most questions and would otherwise swallow news and
    # price questions.
    _ROUTE_KEYWORDS = (
        ("news", ("news", "latest", "recent", "announcement", "update")),
        ("price", ("price", "stock", "value", "market", "trading", "cost")),
        ("company_info", ("profile", "about", "information", "details",
                          "what", "who", "describe")),
    )

    def __init__(self):
        self.llm = llm
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.faiss_index = None
        self.company_data = {}
        self.news_sources = [
            FinvizNews(),
            MarketWatchNews(),
            YahooRSSNews()
        ]
        self.load_faiss_index()

    def load_faiss_index(self):
        """Load the FAISS index and the per-company JSON profiles.

        The index and the profiles are loaded independently, and each profile
        file separately, so a missing index or one unreadable file does not
        discard everything else.
        """
        try:
            self.faiss_index = faiss.read_index("company_profiles.index")
        except Exception as e:
            print(f"Error loading FAISS index: {e}")
        try:
            file_names = os.listdir('company_data')
        except OSError as e:
            print(f"Error loading company data: {e}")
            return
        for file in file_names:
            path = os.path.join('company_data', file)
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    self.company_data[file.replace('.txt', '')] = json.load(f)
            except (OSError, ValueError) as e:
                print(f"Error loading company profile {path}: {e}")

    def route_and_respond(self, query: str, company: str) -> str:
        """Answer *query* about *company*, prefixed with the route taken."""
        query_type = self._classify_query(query.lower())
        route_message = f"\n[Taking {query_type.upper()} route]\n\n"
        if query_type == "company_info":
            context = self._get_company_context(query, company)
        elif query_type == "news":
            context = self._get_news_context(company)
        elif query_type == "price":
            context = self._get_price_context(company)
        else:
            return route_message + "I'm not sure how to handle this query. Please ask about company information, news, or price data."
        prompt = self._create_prompt(query, context, query_type)
        try:
            response = self._generate_response(prompt)
        except Exception as e:
            # Report endpoint failures in the chat instead of raising.
            response = f"Error generating response: {e}"
        return route_message + response

    def _classify_query(self, query: str) -> str:
        """Return the first route whose keywords occur in *query*, else 'unknown'."""
        query = query.lower()
        for route, keywords in self._ROUTE_KEYWORDS:
            if any(word in query for word in keywords):
                return route
        return "unknown"

    def _get_company_context(self, query: str, company: str) -> str:
        """Return the stored profile for *company* as text.

        Profiles are keyed by the name part of the dropdown label
        (``'Apple (AAPL)'`` -> ``'Apple'``).  *query* is kept for interface
        compatibility.  The earlier FAISS search on it was discarded unused,
        and it made this lookup fail whenever the index had not loaded.
        """
        company_name = company.split(" (")[0]
        company_info = self.company_data.get(company_name)
        if not company_info:
            return f"No stored company information found for {company_name}."
        if isinstance(company_info, str):
            return company_info
        return json.dumps(company_info, indent=2, ensure_ascii=False)

    def _get_news_context(self, company: str) -> str:
        """Collect up to five de-duplicated headlines from all news sources."""
        all_news = []
        for source in self.news_sources:
            all_news.extend(source.get_news(company))
        seen_titles = set()
        unique_news = []
        for news in all_news:
            if news['title'] not in seen_titles:
                seen_titles.add(news['title'])
                unique_news.append(news)
        if not unique_news:
            return "No recent news found."
        parts = ["Recent news articles:\n\n"]
        for news in unique_news[:5]:
            parts.append(f"Source: {news['source']}\n")
            parts.append(f"Title: {news['title']}\n")
            if news['description']:
                parts.append(f"Description: {news['description']}\n")
            parts.append(f"Date: {news['date']}\n\n")
        return "".join(parts)

    @staticmethod
    def _format_number(value) -> str:
        """Format numbers with thousands separators; ``None`` becomes 'N/A'.

        Non-numeric values are passed through as text.  The previous inline
        ``{...:,}`` format raised ValueError on the 'N/A' fallback.
        """
        if value is None:
            return 'N/A'
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return f"{value:,}"
        return str(value)

    def _get_price_context(self, company: str) -> str:
        """Summarise current quote data for *company* from ``yf.Ticker.info``."""
        try:
            ticker = company.split('(')[-1].replace(')', '')
            info = yf.Ticker(ticker).info
            fmt = self._format_number
            return "\n".join([
                "Current Stock Information:",
                f"Price: ${info.get('currentPrice', 'N/A')}",
                f"Day Range: ${info.get('dayLow', 'N/A')} - ${info.get('dayHigh', 'N/A')}",
                f"52 Week Range: ${info.get('fiftyTwoWeekLow', 'N/A')} - ${info.get('fiftyTwoWeekHigh', 'N/A')}",
                f"Market Cap: ${fmt(info.get('marketCap'))}",
                f"Volume: {fmt(info.get('volume'))}",
                f"P/E Ratio: {info.get('trailingPE', 'N/A')}",
                # NOTE(review): confirm whether Yahoo reports dividendYield
                # as a fraction or as a percentage before trusting the '%'.
                f"Dividend Yield: {info.get('dividendYield', 'N/A')}%",
            ])
        except Exception as e:
            return f"Error fetching price data: {str(e)}"

    def _create_prompt(self, query: str, context: str, query_type: str) -> str:
        """Create the LLM prompt; news queries get a summary-oriented template."""
        if query_type == "news":
            return "\n".join([
                "Based on the following news articles, please provide a summary addressing the query.",
                "Context:",
                f"{context}",
                f"Query: {query}",
                "Please analyze the news and provide:",
                "1. Key points from the recent articles",
                "2. Any significant developments or trends",
                "3. Potential impact on the company",
                "4. Overall sentiment (positive/negative/neutral)",
                "Response should be clear, concise, and focused on the most relevant information.",
            ])
        return "\n".join([
            f"Based on the following {query_type} context, please answer the question.",
            "Context:",
            f"{context}",
            f"Question: {query}",
            "Please provide a clear and concise answer based on the given context.",
        ])

    def _generate_response(self, prompt: str) -> str:
        """Generate a response for *prompt* with the shared LLM endpoint.

        This previously used a nonexistent ``self.llm_agent`` (local
        tokenizer/model) and raised AttributeError on every call.
        """
        response = self.llm.invoke(prompt)
        # Strip an echoed prompt, as the original local-model path did.
        return str(response).replace(prompt, "").strip()
def analyze_stock(company: str, lookback_days: int = 180) -> Tuple[str, go.Figure, go.Figure]:
    """Download price history for *company* and run the full analysis.

    Args:
        company: dropdown label, a key of ``COMPANIES``.
        lookback_days: calendar days of history to request.

    Returns:
        ``(llm_text, price_figure, indicator_figure)``.  When no data comes
        back, or anything raises, the text is a message and both figures
        are ``None``.
    """
    try:
        symbol = COMPANIES[company]
        end_date = datetime.now()
        start_date = end_date - timedelta(days=lookback_days)
        data = yf.download(symbol, start=start_date, end=end_date)
        if len(data) == 0:
            return "No data available.", None, None

        analysis = AgenticRAGFramework().analyze(symbol, data)
        price_fig, indicator_fig = create_plots(analysis)[:2]
        return analysis['llm_analysis'], price_fig, indicator_fig
    except Exception as e:
        return f"Error analyzing stock: {str(e)}", None, None
def _ohlcv_columns(data: pd.DataFrame) -> Tuple:
    """Return the (close, open, volume) column keys for flat or two-level columns."""
    if isinstance(data.columns, pd.MultiIndex):
        ticker = data.columns.get_level_values(1)[0]
        return ('Close', ticker), ('Open', ticker), ('Volume', ticker)
    return 'Close', 'Open', 'Volume'


def _price_volume_figure(data: pd.DataFrame, close_col, open_col, volume_col) -> go.Figure:
    """Build the price panel (with SMA20/SMA50) above a volume bar panel."""
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        subplot_titles=('Price Analysis', 'Volume'),
        row_heights=[0.7, 0.3]
    )
    fig.add_trace(
        go.Scatter(x=data.index, y=data[close_col], name='Price',
                   line=dict(color='blue', width=2)),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(x=data.index, y=data['SMA_20'], name='SMA20',
                   line=dict(color='orange', width=1.5)),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(x=data.index, y=data['SMA_50'], name='SMA50',
                   line=dict(color='red', width=1.5)),
        row=1, col=1
    )
    # Red bar when close < open, green otherwise.  Computed column-wise
    # instead of iterating rows; NaN comparisons are False, so rows with
    # missing prices fall back to green exactly as before.
    down_day = data[close_col].astype(float) < data[open_col].astype(float)
    colors = np.where(down_day, 'red', 'green').tolist()
    fig.add_trace(
        go.Bar(x=data.index, y=data[volume_col], marker_color=colors, name='Volume'),
        row=2, col=1
    )
    fig.update_layout(
        height=400,
        showlegend=True,
        xaxis_rangeslider_visible=False,
        plot_bgcolor='white',
        paper_bgcolor='white'
    )
    return fig


def _indicator_figure(data: pd.DataFrame, close_col) -> go.Figure:
    """Build stacked RSI, MACD and Bollinger Band panels."""
    fig = make_subplots(
        rows=3, cols=1,
        shared_xaxes=True,
        subplot_titles=('RSI', 'MACD', 'Bollinger Bands'),
        row_heights=[0.33, 0.33, 0.34],
        vertical_spacing=0.03
    )
    # RSI with the 70/30 overbought/oversold guides.
    fig.add_trace(
        go.Scatter(x=data.index, y=data['RSI'], name='RSI',
                   line=dict(color='purple', width=1.5)),
        row=1, col=1
    )
    fig.add_hline(y=70, line_dash="dash", line_color="red", row=1, col=1)
    fig.add_hline(y=30, line_dash="dash", line_color="green", row=1, col=1)
    # MACD and its signal line.
    fig.add_trace(
        go.Scatter(x=data.index, y=data['MACD'], name='MACD',
                   line=dict(color='blue', width=1.5)),
        row=2, col=1
    )
    fig.add_trace(
        go.Scatter(x=data.index, y=data['Signal_Line'], name='Signal',
                   line=dict(color='orange', width=1.5)),
        row=2, col=1
    )
    # Price between the Bollinger Bands.
    fig.add_trace(
        go.Scatter(x=data.index, y=data[close_col], name='Price',
                   line=dict(color='blue', width=2)),
        row=3, col=1
    )
    fig.add_trace(
        go.Scatter(x=data.index, y=data['BB_upper'], name='Upper BB',
                   line=dict(color='gray', dash='dash')),
        row=3, col=1
    )
    fig.add_trace(
        go.Scatter(x=data.index, y=data['BB_lower'], name='Lower BB',
                   line=dict(color='gray', dash='dash')),
        row=3, col=1
    )
    fig.update_layout(
        height=400,
        showlegend=True,
        plot_bgcolor='white',
        paper_bgcolor='white'
    )
    return fig


def create_plots(analysis: Dict) -> List[go.Figure]:
    """Create analysis plots.

    Args:
        analysis: result of ``AgenticRAGFramework.analyze``; only
            ``analysis['technical_analysis']['processed_data']`` is read.

    Returns:
        ``[price_and_volume_figure, technical_indicator_figure]``.
        Works with both flat and two-level (ticker) price columns.
    """
    data = analysis['technical_analysis']['processed_data']
    close_col, open_col, volume_col = _ohlcv_columns(data)
    return [
        _price_volume_figure(data, close_col, open_col, volume_col),
        _indicator_figure(data, close_col),
    ]
# Router shared by all chat requests, created on first use so the sentence
# encoder and FAISS index (loaded in ChatbotRouter.__init__) are loaded once
# instead of on every message.
_CHATBOT_ROUTER = None


def _get_chatbot_router() -> "ChatbotRouter":
    """Return the shared ChatbotRouter, creating it on first use."""
    global _CHATBOT_ROUTER
    if _CHATBOT_ROUTER is None:
        # ChatbotRouter takes no constructor arguments; the previous
        # ChatbotRouter(LlamaAgent()) call referenced an undefined name.
        _CHATBOT_ROUTER = ChatbotRouter()
    return _CHATBOT_ROUTER


def chatbot_response(message: str, company: str, history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Handle chatbot interactions.

    Args:
        message: the user's question.
        company: dropdown label such as ``'Apple (AAPL)'``.
        history: existing ``(user, bot)`` pairs; ``None`` (e.g. after the
            chat was cleared) is treated as empty.

    Returns:
        A new history list with ``(message, response)`` appended, or a copy
        of the existing history when *message* is blank.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history
    try:
        response = _get_chatbot_router().route_and_respond(message, company)
    except Exception as e:
        # Show the failure in the chat rather than as a UI error.
        response = f"Error: {e}"
    return history + [(message, response)]
| # def create_interface(): | |
| # """Create Gradio interface""" | |
| # with gr.Blocks() as interface: | |
| # gr.Markdown("# Stock Analysis with Multi-Source News") | |
| # with gr.Row(): | |
| # with gr.Column(scale=2): | |
| # company = gr.Dropdown( | |
| # choices=list(COMPANIES.keys()), | |
| # value=list(COMPANIES.keys())[0], | |
| # label="Company" | |
| # ) | |
| # lookback = gr.Slider( | |
| # minimum=30, | |
| # maximum=365, | |
| # value=180, | |
| # step=1, | |
| # label="Analysis Period (days)" | |
| # ) | |
| # analyze_btn = gr.Button("Analyze", variant="primary") | |
| # with gr.Row(): | |
| # with gr.Column(scale=1): | |
| # chatbot = gr.Chatbot(label="Stock Assistant", height=400) | |
| # with gr.Row(): | |
| # msg = gr.Textbox( | |
| # label="Ask about company info, news, or prices", | |
| # scale=4 | |
| # ) | |
| # submit = gr.Button("Submit", scale=1) | |
| # clear = gr.Button("Clear", scale=1) | |
| # with gr.Column(scale=2): | |
| # analysis = gr.Textbox( | |
| # label="Technical Analysis Summary", | |
| # lines=10 | |
| # ) | |
| # chart1 = gr.Plot(label="Price and Volume Analysis") | |
| # chart2 = gr.Plot(label="Technical Indicators") | |
| # # Event handlers | |
| # analyze_btn.click( | |
| # fn=analyze_stock, | |
| # inputs=[company, lookback], | |
| # outputs=[analysis, chart1, chart2] | |
| # ) | |
| # submit.click( | |
| # fn=chatbot_response, | |
| # inputs=[msg, company, chatbot], | |
| # outputs=chatbot | |
| # ) | |
| # msg.submit( | |
| # fn=chatbot_response, | |
| # inputs=[msg, company, chatbot], | |
| # outputs=chatbot | |
| # ) | |
| # clear.click(lambda: None, None, chatbot, queue=False) | |
| # return interface | |
def create_interface():
    """Create Gradio interface.

    Layout: a top row holding the controls plus the LLM summary (left
    column) and the two charts (right column), a Markdown separator, then
    the chatbot and its input row.  The Analyze button calls
    ``analyze_stock``; the Submit button and pressing Enter in the textbox
    both call ``chatbot_response``.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    # NOTE(review): the block nesting below was reconstructed from a copy
    # with stripped indentation; verify against the deployed file.
    with gr.Blocks() as interface:
        gr.Markdown("# Stock Analysis with Multi-Source News")
        # Top section with analysis components
        with gr.Row():
            # Left column - Controls and Summary
            with gr.Column(scale=1):
                company = gr.Dropdown(
                    choices=list(COMPANIES.keys()),
                    value=list(COMPANIES.keys())[0],
                    label="Company"
                )
                lookback = gr.Slider(
                    minimum=30,
                    maximum=365,
                    value=180,
                    step=1,
                    label="Analysis Period (days)"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
                analysis = gr.Textbox(
                    label="Technical Analysis Summary",
                    lines=30
                )
            # Right column - Charts
            with gr.Column(scale=2):
                chart1 = gr.Plot(label="Price and Volume Analysis")
                chart2 = gr.Plot(label="Technical Indicators")
        gr.Markdown("---")  # Separator
        # Bottom section - Chatbot
        with gr.Row():
            chatbot = gr.Chatbot(label="Stock Assistant", height=400)
        with gr.Row():
            msg = gr.Textbox(
                label="Ask about company info, news, or prices",
                scale=4
            )
            submit = gr.Button("Submit", scale=1)
            clear = gr.Button("Clear", scale=1)
        # Event handlers
        # analyze_stock returns (summary text, price figure, indicator figure).
        analyze_btn.click(
            fn=analyze_stock,
            inputs=[company, lookback],
            outputs=[analysis, chart1, chart2]
        )
        # Button click and Enter in the textbox share the same handler.
        submit.click(
            fn=chatbot_response,
            inputs=[msg, company, chatbot],
            outputs=chatbot
        )
        msg.submit(
            fn=chatbot_response,
            inputs=[msg, company, chatbot],
            outputs=chatbot
        )
        # Set the chatbot's value to None, i.e. clear the conversation.
        clear.click(lambda: None, None, chatbot, queue=False)
    return interface
if __name__ == "__main__":
    # Script entry point: build the Gradio UI and start the server with
    # debug=True passed through to launch().
    interface = create_interface()
    interface.launch(debug=True)