Spaces:
Build error
Build error
File size: 7,992 Bytes
27c4848 abb8605 7cbc944 c6501a9 d2e53aa c2837fb 1df6b88 27c4848 0548369 27c4848 0548369 c6501a9 0548369 c6501a9 27c4848 c6501a9 27c4848 1df6b88 27c4848 1df6b88 d2e53aa 1df6b88 d2e53aa 1df6b88 27c4848 1df6b88 27c4848 1df6b88 27c4848 1df6b88 27c4848 1df6b88 c2837fb 1df6b88 c2837fb 1df6b88 27c4848 c6501a9 1df6b88 7cbc944 27c4848 c6501a9 27c4848 7cbc944 c6501a9 1df6b88 c6501a9 7cbc944 c6501a9 7cbc944 c6501a9 27c4848 1df6b88 27c4848 d2e53aa 27c4848 c6501a9 27c4848 d2e53aa 27c4848 c6501a9 373c86e d2e53aa c6501a9 373c86e 7cbc944 c6501a9 373c86e 27c4848 c6501a9 27c4848 c6501a9 27c4848 c6501a9 5474c75 c6501a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
# app.py
from typing import Annotated, Any, Dict, List, Literal, Sequence, TypedDict, Union
import os
import asyncio
from operator import itemgetter
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import uvicorn
import json
from langchain_core.messages import HumanMessage, AIMessage
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage
import logging
import sys
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from openai import OpenAI
import yfinance as yf
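# langchain_core imports (the LangChain integration is noted as only partially working).
# Several of these re-import names already imported above (e.g. SystemMessage,
# ChatPromptTemplate, MessagesPlaceholder); being later, these bindings are the ones in effect.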
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
# Route INFO-and-above log records from every logger to stdout with a timestamped format.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stdout_handler])
logger = logging.getLogger(__name__)

# Pull variables from a local .env file into the process environment
# (notably OPENAI_API_KEY, which the OpenAI chat model needs).
logger.info("Loading environment variables...")
load_dotenv()

# The ASGI application that the routes in this module are registered on.
logger.info("Initializing FastAPI application...")
app = FastAPI()
# Define stock-related tools
@tool
def get_stock_price(symbol: str) -> str:
    """Get the current stock price for a given ticker symbol (e.g. "AAPL")."""
    # With @tool, the docstring above is the description the LLM sees, so it stays short.
    # symbol: ticker symbol passed to yfinance.
    # Returns a human-readable sentence; failures are reported as text, never raised,
    # so the agent can relay them to the user.
    logger.info(f"Getting stock price for {symbol}")
    try:
        info = yf.Ticker(symbol).info
        # 'currentPrice' may be absent for some tickers; try 'regularMarketPrice'
        # before reporting the price as unavailable instead of printing "$N/A".
        price = info.get('currentPrice')
        if price is None:
            price = info.get('regularMarketPrice')
        if price is None:
            return f"The current price of {symbol} is not available"
        return f"The current price of {symbol} is ${price}"
    except Exception as e:
        logger.exception(f"Error getting stock price: {str(e)}")
        return f"Error getting stock price: {str(e)}"
@tool
def get_stock_info(symbol: str) -> str:
    """Get basic information about a stock."""
    # With @tool, the docstring above is the description the LLM sees.
    # symbol: ticker symbol passed to yfinance.
    # Returns a multi-line summary; any failure is reported as text, never raised.
    logger.info(f"Getting stock info for {symbol}")

    def _money(value, spec=""):
        # Format a numeric value as "$<value>" using `spec`; anything else
        # (missing key, None) becomes "N/A". Applying a numeric spec such as
        # ",.2f" to the 'N/A' placeholder raises ValueError, which would turn a
        # single missing field into an error for the whole request.
        if isinstance(value, (int, float)):
            return f"${value:{spec}}"
        return "N/A"

    try:
        stock = yf.Ticker(symbol)
        info = stock.info
        return f"""
{symbol} Information:
- Company Name: {info.get('longName', 'N/A')}
- Sector: {info.get('sector', 'N/A')}
- Market Cap: {_money(info.get('marketCap'), ',.2f')}
- 52 Week High: {_money(info.get('fiftyTwoWeekHigh'))}
- 52 Week Low: {_money(info.get('fiftyTwoWeekLow'))}
"""
    except Exception as e:
        logger.exception(f"Error getting stock info: {str(e)}")
        return f"Error getting stock info: {str(e)}"
@tool
def get_stock_history(symbol: str, period: str = "1mo") -> str:
    """Get historical stock data for a given period.

    Valid periods: 1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max.
    """
    # With @tool, the docstring above is the description the LLM sees; listing the
    # accepted period strings helps the model pick a valid one.
    # symbol: ticker symbol; period: passed straight to yfinance's history().
    # Returns a multi-line summary; failures are reported as text, never raised.
    logger.info(f"Getting stock history for {symbol} over {period}")
    try:
        stock = yf.Ticker(symbol)
        hist = stock.history(period=period)
        if hist.empty:
            return f"No historical data found for {symbol}"
        latest = hist.iloc[-1]
        # High/Low are taken over every row of the requested period so they match
        # the "({period})" heading; Close and Volume are from the most recent row.
        period_high = hist['High'].max()
        period_low = hist['Low'].min()
        return f"""
{symbol} Historical Data ({period}):
- Latest Close: ${latest['Close']:.2f}
- Period High: ${period_high:.2f}
- Period Low: ${period_low:.2f}
- Latest Volume: {latest['Volume']:,.0f}
"""
    except Exception as e:
        logger.exception(f"Error getting stock history: {str(e)}")
        return f"Error getting stock history: {str(e)}"
logger.info("Creating tools list...")
tools = [get_stock_price, get_stock_info, get_stock_history]


def _logged_step(start_msg, ok_msg, failure_prefix, build):
    """Run ``build()`` between progress log lines and return its result.

    Logs ``start_msg`` first and ``ok_msg`` on success. On any exception it logs
    ``"<failure_prefix>: <error>"`` and re-raises, so a failed setup step still
    aborts module import.
    """
    logger.info(start_msg)
    try:
        built = build()
        logger.info(ok_msg)
    except Exception as exc:
        logger.error(f"{failure_prefix}: {str(exc)}")
        raise
    return built


# Chat model used by the agent (temperature 0.5).
model = _logged_step(
    "Initializing ChatOpenAI model...",
    "ChatOpenAI model initialized successfully",
    "Failed to initialize ChatOpenAI model",
    lambda: ChatOpenAI(temperature=0.5),
)

# Prompt layout: fixed system message, prior conversation turns, the new user
# input, then the scratchpad slot the agent fills with its intermediate steps.
prompt = _logged_step(
    "Creating prompt template...",
    "Prompt template created successfully",
    "Failed to create prompt template",
    lambda: ChatPromptTemplate.from_messages([
        SystemMessage(content="You are a helpful AI assistant specialized in stock market information. Use the available tools to provide accurate stock data."),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]),
)

agent = _logged_step(
    "Creating OpenAI functions agent...",
    "Agent created successfully",
    "Failed to create agent",
    lambda: create_openai_functions_agent(model, tools, prompt),
)

# Executor that runs the agent/tool loop; verbose=True echoes each step.
agent_executor = _logged_step(
    "Creating agent executor...",
    "Agent executor created successfully",
    "Failed to create agent executor",
    lambda: AgentExecutor(agent=agent, tools=tools, verbose=True),
)
# Define agent nodes
def create_agent_node():
    """Build a tool-calling chain: a fixed system prompt plus a message history.

    Returns the runnable ``prompt | model.bind_tools(tools=tools)``. Invoke it
    with a mapping that supplies a ``messages`` list.

    NOTE(review): not called anywhere in the visible part of this file; the
    WebSocket handler uses ``agent_executor`` instead.
    """
    # Human-readable tool list embedded in the system message.
    tool_guide = """You are a helpful AI assistant with access to the following tools:
1. get_stock_price: Get the current stock price for a given symbol
2. get_stock_info: Get basic information about a stock
3. get_stock_history: Get historical stock data for a given period
Use these tools to assist the user with their requests. When a tool is needed, call the appropriate function.
"""
    node_messages = [
        SystemMessage(content=tool_guide),
        MessagesPlaceholder(variable_name="messages"),
    ]
    node_prompt = ChatPromptTemplate.from_messages(node_messages)
    tool_calling_model = model.bind_tools(tools=tools)
    return node_prompt | tool_calling_model
# WebSocket endpoint for real-time communication
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Chat with the stock agent over a WebSocket.

    Each received text frame is one user turn. The answer is sent back as JSON
    ``{"type": "ai_message", "content": ...}``. If answering fails, the client gets
    ``{"type": "error", "content": ...}`` and the connection stays open.
    Conversation history lives in memory for the lifetime of this connection.
    """
    logger.info("New WebSocket connection established")
    await websocket.accept()
    # Completed turns only; the current message is passed separately as "input"
    # because the prompt already has a ("human", "{input}") slot.
    chat_history = []
    try:
        while True:
            # Kept outside the per-message try: a disconnect here must end the
            # loop, not be reported back to a client that is already gone.
            data = await websocket.receive_text()
            logger.info(f"Received message: {data}")
            try:
                # ainvoke keeps the event loop free while the model and tools run;
                # the synchronous invoke() would stall every other connection.
                response = await agent_executor.ainvoke({
                    "input": data,
                    "chat_history": chat_history,
                })
                output = response["output"]
            except Exception as e:
                logger.exception(f"Error processing message: {str(e)}")
                await websocket.send_json({"type": "error", "content": f"Error processing message: {str(e)}"})
                continue
            # Record the turn only once it succeeded, so the history never holds an
            # unanswered message and the model never sees the current message twice.
            chat_history.append(HumanMessage(content=data))
            chat_history.append(AIMessage(content=output))
            logger.info(f"Agent response: {output}")
            await websocket.send_json({"type": "ai_message", "content": output})
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.exception(f"WebSocket error: {str(e)}")
        try:
            await websocket.close()
        except Exception:
            # Best-effort close: the socket may already be closed.
            logger.debug("Ignoring error while closing WebSocket", exc_info=True)
# Serve the HTML frontend
@app.get("/")
async def get():
    """Serve the frontend page, ``index.html``, from this module's directory."""
    logger.info("Serving index.html")
    # Resolve against the module's directory rather than the process's current
    # working directory, so the route works however the server was launched.
    index_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    return FileResponse(index_path)
# Mount static files
logger.info("Mounting static files...")
# Resolve "static" next to this module rather than relative to the current working
# directory; StaticFiles checks the directory exists at construction time.
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")),
    name="static",
)
if __name__ == "__main__":
    # Direct `python app.py` entry point: serve on all interfaces, port 7860.
    logger.info("Starting uvicorn server...")
    try:
        uvicorn.run(app, host="0.0.0.0", port=7860)
    except Exception as e:
        logger.exception(f"Failed to start uvicorn server: {str(e)}")
        raise
|