Spaces:
Sleeping
Sleeping
| from typing import List | |
| import html | |
| import asyncio | |
| from pydantic_models import Product | |
| from pydantic_ai_agents import Chatbot | |
| import pydantic_ai_agents | |
| import json | |
| import fastapi | |
| import uvicorn | |
| import logging | |
| import sys | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi import FastAPI, WebSocket | |
| # from starlette.middleware.cors import CORSMiddleware | |
# Process-wide logging: INFO and above, written to stderr.
logging.basicConfig(level=logging.INFO, stream=sys.stderr)
logger = logging.getLogger(__name__)

app = FastAPI()

# Wide-open CORS policy: any origin, any method, any header, credentials allowed.
# NOTE(review): the CORS spec does not allow a wildcard origin on credentialed
# requests; list explicit origins here if cookies/auth are ever relied upon.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
async def ws_endpoint(ws: WebSocket):
    """Serve one chat client over a WebSocket until the connection ends.

    Each connection gets its own ``Chatbot``. Every incoming text frame is
    parsed as JSON and dispatched on its ``type`` key:

    * ``"chat"``: ``content`` is a user query. It is answered in a
      background task, so this receive loop keeps reading frames (e.g.
      ``prompt_response`` replies) while the chat is in progress.
    * ``"prompt_response"``: ``content`` is handed to
      ``chatbot.coach.set_userResp`` together with ``prompt_id``. The client
      gets an ``ackPromptUser`` saying whether a matching prompt was found.

    Other ``type`` values are ignored. A frame that is not valid JSON, or
    lacks ``type``, ends the connection; the error is logged.

    On exit, any chat tasks still running are cancelled. The socket is
    closed only if both sides still consider it open.

    NOTE(review): no ``@app.websocket(...)`` decorator is visible on this
    function, so FastAPI does not route to it as written. Confirm how it is
    registered; the decorator line may have been lost.
    """
    from fastapi.websockets import WebSocketState

    await ws.accept()
    # id() of the socket object is only used to tag this connection's log lines.
    user_id = id(ws)
    logger.info(f'New Websocket Connection: {user_id}')
    chatbot = Chatbot(top_k=12, ws=ws)
    # Strong references to in-flight chat tasks. The event loop keeps only
    # weak references to tasks, so an unreferenced task may be
    # garbage-collected before it finishes. Finished tasks remove themselves.
    chat_tasks = set()

    async def handle_chat(user_query: str):
        """Answer one chat query and push the result to the client.

        On success, sends a ``response`` message followed by ``ackChat``.
        If anything raises, logs it and sends an ``error`` message carrying
        the exception text instead.
        """
        try:
            resp = await chatbot.chat(user_query)
            dic = {
                'type': 'response',
                'message': resp.output.message,
                # Falsy entries in the product list are dropped.
                'products': [p.model_dump() for p in resp.output.products if p],
                'recc': resp.output.recommended.model_dump(),
                'flow': resp.output.flow,
                'steps': resp.output.steps,
            }
            logger.info(f'Res ws sent: {dic}')
            await ws.send_json(dic)
            logger.info('Sending ackChat ...')
            await ws.send_json({
                "type": "ackChat",
                "message": "Chat complete",
            })
            logger.info('Sent ackChat !')
        except Exception as e:
            logger.exception(f'chat oopsie: {e}')
            await ws.send_json({
                "type": "error",
                "message": f"{e}",
            })

    try:
        while True:
            data = await ws.receive_text()
            msg = json.loads(data)
            if msg['type'] == 'chat':
                user_query = msg['content']
                task = asyncio.create_task(handle_chat(user_query))
                chat_tasks.add(task)
                task.add_done_callback(chat_tasks.discard)
            elif msg['type'] == 'prompt_response':
                prompt_id = msg.get('prompt_id')
                isRec = chatbot.coach.set_userResp(msg['content'], prompt_id)
                await ws.send_json({
                    "type": "ackPromptUser",
                    "message": "Response received" if isRec else "No matching prompt :(",
                })
    except fastapi.WebSocketDisconnect:
        # A client going away is the normal end of a session, not an error.
        logger.info(f'Websocket disconnected: {user_id}')
    except Exception as e:
        logger.exception(f'Oops: {e}')
    finally:
        # Nobody is left to receive chat results; stop work that would
        # otherwise fail when it tries to send on this socket.
        for task in list(chat_tasks):
            task.cancel()
        # close() on a socket that is already disconnected raises
        # RuntimeError, so only close one that both sides still hold open.
        if (ws.application_state == WebSocketState.CONNECTED
                and ws.client_state == WebSocketState.CONNECTED):
            await ws.close()
| # @app.get('/') | |
| # async def home(): | |
| # """Home Page Frontend""" | |
| # html_content = """ | |
| # """ | |
| # return fastapi.responses.HTMLResponse(html_content) | |
# Serve the frontend from the "static" directory (resolved against the current
# working directory). html=True makes directory requests return index.html
# when one exists. The "/" mount matches every path, and Starlette checks
# routes in registration order, so it must be registered after all other routes.
app.mount(
    "/",
    StaticFiles(html=True, directory='static'),
    name='static',
)

if __name__ == '__main__':
    # Run directly under uvicorn, listening on all interfaces.
    uvicorn.run(app, port=7860, host="0.0.0.0")