Spaces:
Build error
Build error
| # app.py | |
| from typing import Annotated, Any, Dict, List, Literal, Sequence, TypedDict, Union | |
| import os | |
| import asyncio | |
| from operator import itemgetter | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| import uvicorn | |
| import json | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.tools import tool | |
| from langchain.agents import AgentExecutor, create_openai_functions_agent | |
| from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain.schema import SystemMessage | |
| import logging | |
| import sys | |
| from langchain_community.utilities import DuckDuckGoSearchAPIWrapper | |
| from openai import OpenAI | |
| import yfinance as yf | |
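# LangChain core imports (integration only partly working). Several of these re-bind names imported above, e.g. SystemMessage, ChatPromptTemplate, MessagesPlaceholder.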
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.output_parsers import JsonOutputParser | |
| from langchain_core.tools import tool | |
| from langchain_openai import ChatOpenAI | |
# Configure logging: INFO level and above, written to stdout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stdout_handler])
logger = logging.getLogger(__name__)

# Load environment variables (OPENAI_API_KEY is expected among them) via python-dotenv.
logger.info("Loading environment variables...")
load_dotenv()

# Create the FastAPI application; the static mount and handlers are attached later in the file.
logger.info("Initializing FastAPI application...")
app = FastAPI()
# Define stock-related tools
def get_stock_price(symbol: str) -> str:
    """Get the current stock price for a given symbol."""
    # The docstring above is deliberately kept short and unchanged: it also serves as
    # the description the model sees when this function is exposed as a tool.
    # Reads `currentPrice` from yfinance's Ticker.info; a missing key renders as "$N/A".
    # Failures are logged and reported back as a string instead of being raised.
    logger.info(f"Getting stock price for {symbol}")
    try:
        current = yf.Ticker(symbol).info.get('currentPrice', 'N/A')
    except Exception as exc:
        logger.error(f"Error getting stock price: {str(exc)}")
        return f"Error getting stock price: {str(exc)}"
    else:
        return f"The current price of {symbol} is ${current}"
| def get_stock_info(symbol: str) -> str: | |
| """Get basic information about a stock.""" | |
| logger.info(f"Getting stock info for {symbol}") | |
| try: | |
| stock = yf.Ticker(symbol) | |
| info = stock.info | |
| return f""" | |
| {symbol} Information: | |
| - Company Name: {info.get('longName', 'N/A')} | |
| - Sector: {info.get('sector', 'N/A')} | |
| - Market Cap: ${info.get('marketCap', 'N/A'):,.2f} | |
| - 52 Week High: ${info.get('fiftyTwoWeekHigh', 'N/A')} | |
| - 52 Week Low: ${info.get('fiftyTwoWeekLow', 'N/A')} | |
| """ | |
| except Exception as e: | |
| logger.error(f"Error getting stock info: {str(e)}") | |
| return f"Error getting stock info: {str(e)}" | |
def get_stock_history(symbol: str, period: str = "1mo") -> str:
    """Get historical stock data for a given period."""
    # The docstring above is kept unchanged because it also serves as the tool
    # description the model sees.
    # Fetches yfinance history for `period` and summarises only the LAST row.
    # NOTE(review): "High"/"Low" are that last day's values, not the extremes over
    # the whole period. Confirm that this is the intended meaning.
    logger.info(f"Getting stock history for {symbol} over {period}")
    try:
        frame = yf.Ticker(symbol).history(period=period)
        if frame.empty:
            return f"No historical data found for {symbol}"
        last_row = frame.iloc[-1]
        summary = [
            "",
            f"{symbol} Historical Data ({period}):",
            f"- Latest Close: ${last_row['Close']:.2f}",
            f"- High: ${last_row['High']:.2f}",
            f"- Low: ${last_row['Low']:.2f}",
            f"- Volume: {last_row['Volume']:,}",
            "",
        ]
        return "\n".join(summary)
    except Exception as exc:
        logger.error(f"Error getting stock history: {str(exc)}")
        return f"Error getting stock history: {str(exc)}"
# Wrap the plain functions as LangChain tools. AgentExecutor validates `tools` as a
# sequence of BaseTool, so passing the bare functions fails when the executor is built.
# `tool()` derives each tool's name, argument schema and description from the
# function's name, signature and docstring.
logger.info("Creating tools list...")
tools = [tool(get_stock_price), tool(get_stock_info), tool(get_stock_history)]
def _build_component(start_msg, ok_msg, fail_msg, factory):
    """Log *start_msg*, build a component with *factory*, then log *ok_msg*.

    On failure, logs ``"<fail_msg>: <error>"`` and re-raises, so the module
    import (and therefore app startup) aborts.
    """
    logger.info(start_msg)
    try:
        component = factory()
        logger.info(ok_msg)
    except Exception as exc:
        logger.error(f"{fail_msg}: {str(exc)}")
        raise
    return component


# Set up the language model
model = _build_component(
    "Initializing ChatOpenAI model...",
    "ChatOpenAI model initialized successfully",
    "Failed to initialize ChatOpenAI model",
    lambda: ChatOpenAI(temperature=0.5),
)

# Prompt: system instructions, prior chat turns, the current user input, then the
# agent's intermediate tool-calling steps.
prompt = _build_component(
    "Creating prompt template...",
    "Prompt template created successfully",
    "Failed to create prompt template",
    lambda: ChatPromptTemplate.from_messages([
        SystemMessage(content="You are a helpful AI assistant specialized in stock market information. Use the available tools to provide accurate stock data."),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]),
)

# OpenAI-functions agent plus the executor that runs its tool-calling loop.
agent = _build_component(
    "Creating OpenAI functions agent...",
    "Agent created successfully",
    "Failed to create agent",
    lambda: create_openai_functions_agent(model, tools, prompt),
)
agent_executor = _build_component(
    "Creating agent executor...",
    "Agent executor created successfully",
    "Failed to create agent executor",
    lambda: AgentExecutor(agent=agent, tools=tools, verbose=True),
)
# Define agent nodes
# NOTE(review): create_agent_node is not called anywhere in the visible code; the
# WebSocket handler uses `agent_executor` instead.
_AGENT_NODE_SYSTEM_PROMPT = """You are a helpful AI assistant with access to the following tools:
1. get_stock_price: Get the current stock price for a given symbol
2. get_stock_info: Get basic information about a stock
3. get_stock_history: Get historical stock data for a given period
Use these tools to assist the user with their requests. When a tool is needed, call the appropriate function.
"""


def create_agent_node():
    """Build a runnable that pipes a message-history prompt into the tool-bound model.

    Returns:
        ``prompt | model.bind_tools(tools=tools)``. The prompt is the fixed system
        message followed by a ``messages`` placeholder, so callers supply a
        ``messages`` input variable.
    """
    template = ChatPromptTemplate.from_messages([
        SystemMessage(content=_AGENT_NODE_SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="messages"),
    ])
    tool_model = model.bind_tools(tools=tools)
    return template | tool_model
# WebSocket endpoint for real-time communication
# NOTE(review): no route decorator (e.g. `@app.websocket(...)`) is visible on this
# handler, so as shown it is never registered with `app`. Confirm the path the
# frontend connects to and add the decorator.
async def websocket_endpoint(websocket: WebSocket):
    """Run a chat session with the stock agent over one WebSocket connection.

    Each received text frame is one user turn. The agent's answer is sent back as
    ``{"type": "ai_message", "content": ...}``. If processing a message fails, the
    client gets ``{"type": "error", "content": ...}`` and the session continues.
    Chat history lives only for the lifetime of this connection.
    """
    logger.info("New WebSocket connection established")
    await websocket.accept()
    # Prior turns only. The current user message is passed separately as "input",
    # because the prompt already renders it after the chat_history placeholder;
    # appending it beforehand made it appear twice.
    chat_history = []
    try:
        while True:
            # A client disconnect raises WebSocketDisconnect here. It is deliberately
            # outside the per-message try so it reaches the outer handler instead of
            # being answered with an error frame on a closed socket.
            data = await websocket.receive_text()
            logger.info(f"Received message: {data}")
            try:
                # ainvoke runs the agent without blocking the event loop for other clients.
                response = await agent_executor.ainvoke({
                    "input": data,
                    "chat_history": chat_history,
                })
                output = response["output"]
            except Exception as e:
                logger.exception(f"Error processing message: {str(e)}")
                await websocket.send_json({"type": "error", "content": f"Error processing message: {str(e)}"})
                continue
            # Record the turn only once it succeeded, so failed turns leave no dangling user message.
            chat_history.extend([HumanMessage(content=data), AIMessage(content=output)])
            logger.info(f"Agent response: {output}")
            await websocket.send_json({"type": "ai_message", "content": output})
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
        try:
            await websocket.close()
        except Exception:
            # Best effort: the connection may already be closed.
            pass
# Serve the HTML frontend
# NOTE(review): no route decorator (e.g. `@app.get("/")`) is visible on this handler,
# so as shown it is not registered with `app`. Confirm the intended path.
async def get():
    """Return ``index.html`` as the frontend page.

    The path is relative, so it is resolved against the process's working directory.
    """
    page = "index.html"
    logger.info("Serving index.html")
    return FileResponse(page)
# Mount the ./static directory (relative to the working directory) under /static.
# By default StaticFiles checks that the directory exists when it is constructed.
logger.info("Mounting static files...")
_static_files = StaticFiles(directory="static")
app.mount("/static", _static_files, name="static")
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    logger.info("Starting uvicorn server...")
    server_host, server_port = "0.0.0.0", 7860
    try:
        uvicorn.run(app, host=server_host, port=server_port)
    except Exception as exc:
        logger.error(f"Failed to start uvicorn server: {str(exc)}")
        raise