from typing import List
import html
import asyncio
from pydantic_models import Product
from pydantic_ai_agents import Chatbot
import pydantic_ai_agents
import json
import fastapi
import uvicorn
import logging
import sys
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, WebSocket
# from starlette.middleware.cors import CORSMiddleware
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
)
@app.websocket('/ws/chat')
async def ws_endpoint(ws: WebSocket):
await ws.accept()
user_id = id(ws)
logger.info(f'New Websocket Connection: {user_id}')
chatbot = Chatbot(top_k=12, ws=ws)
async def handle_chat(user_query: id):
try:
resp = await chatbot.chat(user_query)
dic = {
'type': 'response',
'message': resp.output.message,
'products': [p.model_dump() for p in resp.output.products if p],
'recc': resp.output.recommended.model_dump(),
'flow': resp.output.flow,
'steps': resp.output.steps
}
logger.info(f'Res ws sent: {dic}')
await ws.send_json(dic)
logger.info('Sending ackChat ...')
await ws.send_json({
"type": "ackChat",
"message": "Chat complete"
})
logger.info('Sent ackChat !')
except Exception as e:
logger.error(f'chat oopsie: {e}')
await ws.send_json({
"type": "error",
"message": f"{e}"
})
try:
while True:
data = await ws.receive_text()
msg = json.loads(data)
if msg['type'] == 'chat':
user_query = msg['content']
asyncio.create_task(handle_chat(user_query))
elif msg['type'] == 'prompt_response':
prompt_id = msg.get('prompt_id')
isRec = chatbot.coach.set_userResp(msg['content'], prompt_id)
await ws.send_json({
"type": "ackPromptUser",
"message": "Response received" if isRec else "No matching prompt :("
})
except fastapi.WebSocketDisconnect:
logger.error(f'Websocket disconnected: {user_id}')
except Exception as e:
logger.error(f'Oops: {e}')
finally:
await ws.close()
# @app.get('/')
# async def home():
# """Home Page Frontend"""
# html_content = """
# """
# return fastapi.responses.HTMLResponse(html_content)
app.mount("/", StaticFiles(directory='static', html=True), name='static')
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=7860)