File size: 3,474 Bytes
600c235
c928d10
 
b61785d
c928d10
10c3afe
b61785d
 
37d460f
c928d10
37d460f
 
c928d10
 
 
 
 
10c3afe
c928d10
 
10c3afe
b61785d
 
10c3afe
c928d10
 
b61785d
 
 
 
 
 
c928d10
878dd56
c928d10
878dd56
 
 
c928d10
878dd56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b61785d
 
 
 
 
c928d10
 
 
c9027fd
10c3afe
 
 
 
 
b61785d
65c8966
c928d10
b61785d
c928d10
 
 
 
b61785d
878dd56
b61785d
878dd56
b61785d
c928d10
65c8966
 
c928d10
b61785d
 
c928d10
b61785d
c9027fd
b61785d
 
c928d10
 
 
b61785d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import chainlit as cl
from agent_graph import agent_node
from dotenv import load_dotenv
from typing import List, Dict
import time
import uuid
import asyncio

# Pull any variables defined in a local .env file into os.environ before
# configuration is read below.
load_dotenv()

# The assistant needs an OpenAI key; refuse to start (at import time) when it
# is unset or set to an empty string.
if not (openai_api_key := os.getenv("OPENAI_API_KEY")):
    raise ValueError("OpenAI API key is missing in the .env file")

# Conversation transcripts per chat session, keyed by the session ID string
# stored in each Chainlit user session. Each transcript is a list of
# {"role": ..., "content": ...} message dicts.
chat_histories: Dict[str, List[Dict[str, str]]] = {}

def get_unique_session_id() -> str:
    """Return a fresh random identifier for a chat session.

    The value is the canonical 36-character text form of a version-4
    (random) UUID, e.g. ``"1b4e28ba-2fa1-4d3b-a3f5-ef19b5a7633f"``.
    """
    session_uuid = uuid.uuid4()
    return f"{session_uuid}"

_WELCOME_MESSAGE = """📈 Welcome to the AI Stock Assistant!

I'm your intelligent stock market companion. Here's what I can do:

1. Get Real-Time Stock Prices 📊
   • Just type a ticker symbol (e.g., 'AAPL' for Apple)
   • Or ask naturally (e.g., "What's Microsoft's stock price?")

2. Calculate Share Purchases 💰
   • Ask how many shares you can buy (e.g., "How many TSLA shares for $5000?")
   • I'll show you the current price and number of shares

3. Smart Features 🧠
   • I understand company names and ticker symbols
   • I remember context from previous messages
   • I support all major stock exchanges

Popular stocks to try:
• Tech: AAPL (Apple), MSFT (Microsoft), GOOGL (Google)
• Finance: JPM (JPMorgan), BAC (Bank of America)
• Retail: WMT (Walmart), COST (Costco)

What would you like to know about the stock market?"""


@cl.on_chat_start
async def start_chat():
    """Initialize a new chat session and greet the user.

    Generates a fresh session ID, stores it in the Chainlit user session,
    registers an empty transcript for it in ``chat_histories`` and then sends
    the welcome text. Any failure is printed to stdout and also reported to
    the user as a chat message.
    """
    try:
        new_id = get_unique_session_id()
        cl.user_session.set("session_id", new_id)
        chat_histories[new_id] = []

        greeting = cl.Message(content=_WELCOME_MESSAGE)
        await greeting.send()
    except Exception as exc:
        print(f"[Error] Failed to start chat: {exc}")
        await cl.Message(content=f"⚠️ Error starting chat: {exc}").send()

@cl.on_message
async def handle_message(message: cl.Message):
    """Forward a user message to the stock agent and relay its reply.

    The message is appended to the session's transcript from
    ``chat_histories`` and passed, together with that transcript, to
    ``agent_node`` in a worker thread. A dict reply containing an ``"output"``
    key is recorded as an assistant turn and sent back to the user; any other
    reply produces an "invalid response format" notice. Exceptions are
    printed and reported to the user as a chat message.

    Args:
        message: The incoming Chainlit message from the user.
    """
    try:
        # Recover this session's ID, creating one (plus an empty transcript)
        # if none is stored in the user session.
        if not (session_id := cl.user_session.get("session_id")):
            session_id = get_unique_session_id()
            cl.user_session.set("session_id", session_id)
            chat_histories[session_id] = []

        user_text = message.content
        transcript = chat_histories.get(session_id, [])
        transcript.append({"role": "user", "content": user_text})

        # NOTE(review): the transcript handed to the agent already ends with
        # the current user message, which is also passed separately as
        # "input" — confirm agent_node expects that.
        agent_state = dict(input=user_text, chat_history=transcript)

        print(f"[Debug] Processing message: {user_text}")
        # agent_node is invoked as a blocking call, so run it in a worker
        # thread to keep the event loop responsive.
        reply = await asyncio.to_thread(agent_node, agent_state)
        print(f"[Debug] Agent response: {reply}")

        has_output = isinstance(reply, dict) and "output" in reply
        if not has_output:
            await cl.Message(content="❌ Received an invalid response format from the agent.").send()
        else:
            answer = reply["output"]
            transcript.append({"role": "assistant", "content": answer})
            await cl.Message(content=answer).send()

        chat_histories[session_id] = transcript

    except Exception as exc:
        print(f"[Error] Error in handle_message: {exc}")
        await cl.Message(content=f"⚠️ Error: {exc}").send()

@cl.on_chat_end
async def end_chat():
    """Release per-session state when a chat ends.

    Drops the session's transcript from ``chat_histories`` (when one exists)
    and clears the Chainlit user session. Errors are printed, never raised.
    """
    try:
        ending_id = cl.user_session.get("session_id")
        if ending_id:
            # pop with a default is a no-op when no transcript was stored.
            chat_histories.pop(ending_id, None)
        # NOTE(review): confirm the installed Chainlit's user session exposes
        # clear(); if it does not, the AttributeError is caught and printed.
        cl.user_session.clear()
    except Exception as exc:
        print(f"[Error] Failed to clean up chat history: {exc}")