File size: 2,977 Bytes
a8bc862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cfe896
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import List
import html
import asyncio
from pydantic_models import Product
from pydantic_ai_agents import Chatbot
import pydantic_ai_agents
import json
import fastapi
import uvicorn
import logging
import sys
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, WebSocket
# from starlette.middleware.cors import CORSMiddleware

# Process-wide logging: INFO and above, written to stderr.
logging.basicConfig(level=logging.INFO, stream=sys.stderr)
logger = logging.getLogger(__name__)

app = FastAPI()

# Fully permissive CORS: every origin, method and header, with credentials on.
# NOTE(review): the CORS spec (and the FastAPI/Starlette docs) say a "*" origin
# cannot be used for credentialed requests; if cross-origin credentialed access
# is really needed, list the allowed origins explicitly -- TODO confirm intent.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

@app.websocket('/ws/chat')
async def ws_endpoint(ws: WebSocket):
    """Serve one chat session over a WebSocket.

    A fresh ``Chatbot`` is created per connection. Incoming text frames are
    parsed as JSON objects and dispatched on their ``type`` field:

    * ``"chat"`` -- ``content`` is passed to ``chatbot.chat`` in a background
      task, so the receive loop keeps reading frames (including
      ``prompt_response``) while the turn runs. The task sends a ``response``
      frame followed by ``ackChat``, or an ``error`` frame if the turn fails.
    * ``"prompt_response"`` -- ``content`` and the optional ``prompt_id`` are
      handed to ``chatbot.coach.set_userResp``; the result is acknowledged
      with an ``ackPromptUser`` frame.

    Frames that are not JSON objects with a ``type`` (and, for the two types
    above, a ``content``) are answered with an ``error`` frame and the session
    continues. Unknown ``type`` values are logged and ignored. When the loop
    ends, chat tasks still running are cancelled and the socket is closed
    unless the client has already disconnected.
    """
    from fastapi.websockets import WebSocketState

    await ws.accept()
    user_id = id(ws)
    logger.info(f'New Websocket Connection: {user_id}')
    chatbot = Chatbot(top_k=12, ws=ws)
    # The event loop keeps only weak references to tasks, so hold strong ones
    # here until each chat turn finishes (else a task may be collected mid-run).
    chat_tasks = set()

    async def send_error(message):
        """Send an ``error`` frame, tolerating a socket that is already gone."""
        try:
            await ws.send_json({
                "type": "error",
                "message": message
            })
        except Exception as send_exc:
            logger.warning(f'Could not send error to {user_id}: {send_exc}')

    async def handle_chat(user_query: str):
        """Run one chat turn and push its result (or its failure) to the client."""
        try:
            resp = await chatbot.chat(user_query)
            output = resp.output
            dic = {
                'type': 'response',
                'message': output.message,
                # Falsy entries (e.g. None) are skipped rather than serialised.
                'products': [p.model_dump() for p in output.products if p],
                # NOTE(review): assumes `recommended` is never None -- confirm;
                # if it is, the AttributeError below becomes an error frame.
                'recc': output.recommended.model_dump(),
                'flow': output.flow,
                'steps': output.steps
            }
            logger.info(f'Res ws sent: {dic}')
            await ws.send_json(dic)
            logger.info('Sending ackChat ...')
            await ws.send_json({
                "type": "ackChat",
                "message": "Chat complete"
            })
            logger.info('Sent ackChat !')
        except Exception as e:
            # logger.exception keeps the traceback that f'{e}' alone would lose.
            logger.exception(f'chat oopsie: {e}')
            await send_error(f"{e}")

    try:
        while True:
            data = await ws.receive_text()

            # A single bad frame should not tear down the whole session.
            try:
                msg = json.loads(data)
            except json.JSONDecodeError as e:
                logger.warning(f'Invalid JSON from {user_id}: {e}')
                await send_error(f'Invalid JSON: {e}')
                continue
            if not isinstance(msg, dict) or 'type' not in msg:
                await send_error('Malformed message: expected a JSON object with a "type" field')
                continue
            msg_type = msg['type']
            if msg_type in ('chat', 'prompt_response') and 'content' not in msg:
                await send_error(f'Malformed {msg_type} message: missing "content"')
                continue

            if msg_type == 'chat':
                task = asyncio.create_task(handle_chat(msg['content']))
                chat_tasks.add(task)
                task.add_done_callback(chat_tasks.discard)
            elif msg_type == 'prompt_response':
                prompt_id = msg.get('prompt_id')
                isRec = chatbot.coach.set_userResp(msg['content'], prompt_id)
                await ws.send_json({
                    "type": "ackPromptUser",
                    "message": "Response received" if isRec else "No matching prompt :("
                })
            else:
                logger.warning(f'Ignoring unknown message type from {user_id}: {msg_type!r}')

    except fastapi.WebSocketDisconnect:
        # A client going away is routine, not an error.
        logger.info(f'Websocket disconnected: {user_id}')
    except Exception as e:
        logger.exception(f'Oops: {e}')
    finally:
        # Results can no longer be delivered, so stop any in-flight chat turns.
        pending = list(chat_tasks)
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        # Closing again after the client has disconnected fails, so only close
        # a socket the client has not already closed.
        if ws.client_state != WebSocketState.DISCONNECTED:
            try:
                await ws.close()
            except Exception as e:
                logger.debug(f'Closing websocket {user_id} failed: {e}')

# @app.get('/')
# async def home():
#     """Home Page Frontend"""
#     html_content = """
#     """
#     return fastapi.responses.HTMLResponse(html_content)
# Serve the ./static directory at the site root; html=True makes directory
# requests serve their index.html. This catch-all mount is registered after
# the /ws/chat route, so that route is still matched first.
app.mount(
    path="/",
    app=StaticFiles(directory='static', html=True),
    name='static',
)

if __name__ == '__main__':
    # Listen on every interface (0.0.0.0) on port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")