Spaces:
Sleeping
Sleeping
import logging
import os
from pathlib import Path
from typing import Any, Dict

from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from pydantic import ValidationError

from src.models import ItineraryPlan, EvaluationResult, UserRequest, PlanResponse, SystemResponse
from src.graph import build_qiddiya_graph, QiddiyaState
from src.db import init_db, save_plan_result, save_evaluation_result
# Root logging: INFO and above, with timestamp, level and logger name on each line.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    level=logging.INFO,
)
# Logger for this web layer; the name is fixed rather than derived from __name__.
logger = logging.getLogger(name="qiddiya.app")
app = FastAPI(version="1.0.0", title="Qiddiya Smart Guide")

# Fully permissive CORS: any origin, method and header, credentials allowed.
# NOTE(review): combining wildcard origins with allow_credentials=True makes
# Starlette echo back whatever Origin a credentialed request sends. The UI
# appears to be served by this same app (root()), so narrowing origins or
# disabling credentials is probably safe — confirm no cross-origin client
# relies on it before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Directory containing this module; the static/ and templates/ folders sit beside it.
_base = Path(__file__).parent
app.mount(
    "/static",
    StaticFiles(directory=str(_base.joinpath("static"))),
    name="static",
)

# Build the planning graph once at import time; the compiled graph_app is
# shared by the request handlers below.
graph = build_qiddiya_graph()
graph_app = graph.compile()
async def on_startup() -> None:
    """Create the ``data`` directory beside this module, initialise the DB and log readiness."""
    # _base is this module's directory (same value as Path(__file__).parent).
    # The data directory is created before init_db(); presumably init_db
    # stores its database there — confirm in src.db.
    (_base / "data").mkdir(exist_ok=True)
    init_db()
    logger.info("Qiddiya Smart Guide backend started.")
async def health() -> Dict[str, str]:
    """Return a constant ``{"status": "ok"}`` payload for health checks."""
    status: Dict[str, str] = {}
    status["status"] = "ok"
    return status
async def plan_itinerary(request: UserRequest) -> Any:
    """Run the planning graph for *request* and return the itinerary with diagnostics.

    The request is turned into a fresh graph state and the compiled graph is
    invoked. Its ``final_plan`` is validated as an ``ItineraryPlan`` and passed
    to ``save_plan_result``. The result is wrapped in a ``PlanResponse``
    together with the run's logs, reflection round, critique and wait-time
    forecast.

    Returns:
        ``PlanResponse`` on success. On a pydantic ``ValidationError`` raised
        anywhere in the run (e.g. ``final_plan`` not matching
        ``ItineraryPlan``), a 422 ``JSONResponse`` carrying the errors. On
        any other exception, a 500 ``JSONResponse`` with a generic message.
    """
    try:
        initial_state: QiddiyaState = {
            "user_request": request.model_dump(),
            "logs": [],
            "wait_time_forecast": None,
            "raw_plan": None,
            "final_plan": None,
            "critique": None,
            "reflection_round": 0,
        }
        # graph_app.invoke is synchronous. Calling it directly in this
        # ``async def`` would block the event loop for the whole graph run,
        # stalling every other request (including the health endpoint).
        # Run it on the worker thread pool instead.
        # NOTE(review): concurrent requests can now run the graph in parallel
        # threads — confirm the graph nodes keep no shared mutable or
        # thread-affine state.
        result_state = await run_in_threadpool(graph_app.invoke, initial_state)
        plan = ItineraryPlan.model_validate(result_state["final_plan"])
        # Left on the event-loop thread, as before, in case src.db relies on
        # thread-affine resources (e.g. a sqlite3 connection) — TODO confirm.
        save_plan_result(request, plan)
        # The graph may leave these keys missing or None; fall back to empty
        # values so SystemResponse always receives concrete types.
        system = SystemResponse(
            logs=result_state.get("logs") or [],
            reflection_round=int(result_state.get("reflection_round") or 0),
            critique=(result_state.get("critique") or ""),
            wait_time_forecast=result_state.get("wait_time_forecast"),
        )
        return PlanResponse(plan=plan, system=system)
    except ValidationError as ve:
        logger.exception("Validation error during planning")
        # ve.errors() can contain objects json.dumps cannot serialise (e.g.
        # the exception pydantic stores under "ctx" for validator errors),
        # which would make JSONResponse itself raise here. jsonable_encoder
        # converts them to JSON-safe data, as FastAPI's own 422 handler does.
        # NOTE(review): request-body validation presumably happens before this
        # handler runs, so errors caught here come from server-side data; 422
        # is kept for compatibility with existing clients.
        return JSONResponse(
            status_code=422,
            content={"detail": jsonable_encoder(ve.errors())},
        )
    except Exception:
        logger.exception("Unexpected error during planning")
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal server error during planning"},
        )
async def evaluate_plan(request: UserRequest) -> Any:
    """Plan an itinerary for *request*, evaluate it, persist and return the evaluation.

    Runs the compiled graph from a fresh state and validates ``final_plan`` as
    an ``ItineraryPlan``. The plan is scored with
    ``src.evaluation.evaluate_itinerary``, and request, plan and evaluation
    are passed to ``save_evaluation_result``.

    Returns:
        Whatever ``evaluate_itinerary`` returns on success (presumably an
        ``EvaluationResult`` — confirm in src.evaluation). On a pydantic
        ``ValidationError``, a 422 ``JSONResponse`` carrying the errors. On
        any other exception, a 500 ``JSONResponse`` with a generic message.
    """
    try:
        initial_state: QiddiyaState = {
            "user_request": request.model_dump(),
            "logs": [],
            "wait_time_forecast": None,
            "raw_plan": None,
            "final_plan": None,
            "critique": None,
            "reflection_round": 0,
        }
        # graph_app.invoke is synchronous; run it on the worker thread pool so
        # the event loop keeps serving other requests during the graph run.
        # NOTE(review): concurrent requests can now run the graph in parallel
        # threads — confirm the graph nodes keep no shared mutable or
        # thread-affine state.
        result_state = await run_in_threadpool(graph_app.invoke, initial_state)
        plan = ItineraryPlan.model_validate(result_state["final_plan"])
        # Imported at call time (as in the original); an import failure
        # surfaces as the 500 response below rather than at app start-up.
        from src.evaluation import evaluate_itinerary

        # NOTE(review): evaluate_itinerary still runs on the event loop; if it
        # is slow (e.g. calls a model), offload it with run_in_threadpool too.
        evaluation = evaluate_itinerary(request=request, plan=plan)
        # Left on the event-loop thread in case src.db relies on thread-affine
        # resources — TODO confirm.
        save_evaluation_result(request, plan, evaluation)
        return evaluation
    except ValidationError as ve:
        logger.exception("Validation error during evaluation")
        # jsonable_encoder makes pydantic error details (which may contain
        # exception objects or arbitrary inputs) JSON-serialisable, so
        # building the error response cannot itself raise.
        return JSONResponse(
            status_code=422,
            content={"detail": jsonable_encoder(ve.errors())},
        )
    except Exception:
        logger.exception("Unexpected error during evaluation")
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal server error during evaluation"},
        )
async def redirect_gradio() -> RedirectResponse:
    """Send legacy /gradio bookmarks to the main UI via a temporary (302) redirect."""
    destination = "/"
    return RedirectResponse(status_code=302, url=destination)
async def root() -> HTMLResponse:
    """Return the single-page HTML UI from ``templates/index.html``.

    The template is read from disk (as UTF-8) on every request; nothing is cached.
    """
    template_path = _base.joinpath("templates", "index.html")
    page = template_path.read_text(encoding="utf-8")
    return HTMLResponse(content=page)
if __name__ == "__main__":
    # Local dev entry point: serve this app with uvicorn on $PORT (default 7860).
    import uvicorn

    port = int(os.getenv("PORT", "7860"))
    # Pass the already-built app object instead of the import string
    # "app:app". With the string, uvicorn imports this file a second time
    # under the module name "app", re-running all module-level setup
    # (including build_qiddiya_graph()). The string also only resolves if the
    # file is literally named app.py. An import string is needed only for
    # reload/workers, and reload is off here.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)