import asyncio
from contextlib import asynccontextmanager, closing
import os
import sqlite3
from urllib.parse import urlsplit, urlunsplit

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from slowapi.errors import RateLimitExceeded

from app.api.admin import router as admin_router
from app.api.chat import router as chat_router
from app.api.feedback import router as feedback_router
from app.api.health import router as health_router
from app.api.tts import router as tts_router
from app.api.transcribe import router as transcribe_router
from app.core.config import get_settings
from app.core.exceptions import AppError
from app.core.logging import get_logger
from app.pipeline.graph import build_pipeline
from app.security.rate_limiter import limiter, custom_rate_limit_handler
from app.services.embedder import Embedder
from app.services.gemini_client import GeminiClient
from app.services.github_log import GithubLog
from app.services.llm_client import get_llm_client, TpmBucket
from app.services.reranker import Reranker
from app.services.semantic_cache import SemanticCache
from app.services.transcriber import GroqTranscriber
from app.services.tts_client import TTSClient
from app.services.conversation_store import ConversationStore

logger = get_logger(__name__)

# Environments in which interactive API docs and the localhost CORS origin
# are enabled.
_DEV_ENVIRONMENTS = ("local", "staging", "test")


def _is_qdrant_not_found(exc: Exception) -> bool:
    """Return True when Qdrant responded with HTTP 404.

    An ``UnexpectedResponse`` whose ``status_code`` attribute is 404 matches
    directly. Otherwise the exception text is inspected: it must contain
    both "404" and (case-insensitively) "page not found".
    """
    if isinstance(exc, UnexpectedResponse):
        status_code = getattr(exc, "status_code", None)
        if status_code == 404:
            return True
    message = str(exc)
    return "404" in message and "page not found" in message.lower()


def _normalize_qdrant_url(url: str) -> str:
    """
    Normalize QDRANT_URL to an API base URL.

    If the configured URL includes a non-root path (for example, a dashboard URL),
    strip the path and keep scheme + host(+port) only.

    A value without a scheme gets ``http`` when it starts with ``localhost``
    or ``127.0.0.1`` and ``https`` otherwise. Empty or ``None`` input yields
    an empty string; a value with no network location is returned as-is
    (after trimming whitespace and trailing slashes).
    """
    raw = (url or "").strip().rstrip("/")
    if not raw:
        return raw

    if "://" not in raw:
        scheme = "http" if raw.startswith(("localhost", "127.0.0.1")) else "https"
        raw = f"{scheme}://{raw}"

    parsed = urlsplit(raw)
    if not parsed.netloc:
        return raw

    if parsed.path and parsed.path != "/":
        # Drop path, query and fragment; keep only scheme://host[:port].
        return urlunsplit((parsed.scheme, parsed.netloc, "", "", "")).rstrip("/")
    return raw


def _sqlite_row_count(db_path: str) -> int:
    """Return the current interactions row count, or 0 if the table doesn't exist.

    Any other failure (unreadable file, bad path, ...) is also reported as 0
    so that startup never fails on this best-effort check; such failures are
    logged as warnings.
    """
    try:
        # sqlite3's own context manager only commits/rolls back; ``closing``
        # is what actually releases the connection handle.
        with closing(sqlite3.connect(db_path)) as conn:
            return conn.execute("SELECT COUNT(*) FROM interactions").fetchone()[0]
    except sqlite3.OperationalError:
        # Typically "no such table: interactions" on a fresh database.
        return 0
    except Exception as exc:
        logger.warning("Could not count SQLite interactions rows: %s", exc)
        return 0


async def _qdrant_keepalive_loop(
    qdrant: QdrantClient,
    interval_seconds: int,
    stop_event: asyncio.Event,
) -> None:
    """
    Periodically ping Qdrant so the deployment keeps an active connection.
    Uses asyncio.to_thread because qdrant-client methods are synchronous.

    Args:
        qdrant: Client whose ``get_collections`` is called as the ping.
        interval_seconds: Delay between pings; a value <= 0 disables the loop.
        stop_event: Set by the caller to make the loop exit promptly.
    """
    if interval_seconds <= 0:
        return

    while not stop_event.is_set():
        try:
            # Sleep for one interval, but wake immediately if asked to stop.
            await asyncio.wait_for(stop_event.wait(), timeout=interval_seconds)
            break
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is only an alias of the builtin
            # TimeoutError from Python 3.11; naming it explicitly keeps the
            # loop alive on older interpreters too.
            pass

        try:
            await asyncio.to_thread(qdrant.get_collections)
            logger.info("Qdrant keepalive ping succeeded")
        except Exception as exc:
            # A failed ping must never terminate the keepalive loop.
            logger.warning("Qdrant keepalive ping failed: %s", exc)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build all shared services on startup and release them on shutdown.

    Everything the request handlers need (caches, clients, the pipeline, the
    Qdrant keepalive task) is attached to ``app.state`` before yielding.
    """
    settings = get_settings()
    logger.info("Starting PersonaBot API | env=%s", settings.ENVIRONMENT)

    # Durable GitHub interaction log — survives HF Space restarts.
    # When PERSONABOT_WRITE_TOKEN is not set (local dev), GithubLog.enabled=False
    # and all append calls are silent no-ops.
    github_log = GithubLog(
        write_token=settings.PERSONABOT_WRITE_TOKEN or "",
        repo=settings.PERSONABOT_REPO,
    )
    app.state.github_log = github_log

    # Attach the in-memory semantic cache. No external service required.
    app.state.semantic_cache = SemanticCache(
        max_size=settings.SEMANTIC_CACHE_SIZE,
        ttl_seconds=settings.SEMANTIC_CACHE_TTL_SECONDS,
        similarity_threshold=settings.SEMANTIC_CACHE_SIMILARITY_THRESHOLD,
    )
    app.state.conversation_store = ConversationStore(settings.DB_PATH, github_log=github_log)

    # Issue 1: reconstruct SQLite conversation history from the durable GitHub log
    # after an ephemeral HF Space restart. Only triggers when SQLite is empty
    # (fewer than 10 rows) so a healthy Space with accumulated data is never
    # overwritten.
    if github_log.enabled and _sqlite_row_count(settings.DB_PATH) < 10:
        logger.info("SQLite appears empty — attempting reconstruction from durable log.")
        recent = await github_log.load_recent(500)
        if recent:
            app.state.conversation_store.populate_from_records(recent)

    # DagsHub/MLflow experiment tracking — optional, only active when token is set.
    # In prod with DAGSHUB_TOKEN set, experiments are tracked at dagshub.com.
    # In local or test environments, MLflow is a no-op.
    if settings.DAGSHUB_TOKEN:
        import dagshub

        # DAGSHUB_REPO is expected in "owner/name" form.
        repo_parts = settings.DAGSHUB_REPO.split("/")
        dagshub.init(
            repo_owner=repo_parts[0],
            repo_name=repo_parts[1],
            mlflow=True,
            dvc=False,
        )
        logger.info("DagsHub MLflow tracking enabled | repo=%s", settings.DAGSHUB_REPO)

    embedder = Embedder(remote_url=settings.EMBEDDER_URL, environment=settings.ENVIRONMENT)
    reranker = Reranker(remote_url=settings.RERANKER_URL, environment=settings.ENVIRONMENT)

    gemini_client = GeminiClient(
        api_key=settings.GEMINI_API_KEY or "",
        model=settings.GEMINI_MODEL,
        context_path=settings.GEMINI_CONTEXT_PATH,
    )
    app.state.gemini_client = gemini_client

    app.state.transcriber = GroqTranscriber(
        api_key=settings.GROQ_API_KEY or "",
        model=settings.GROQ_TRANSCRIBE_MODEL,
        timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
    )
    app.state.tts_client = TTSClient(
        tts_space_url=settings.TTS_SPACE_URL,
        timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
    )

    from app.services.vector_store import VectorStore
    from app.security.guard_classifier import GuardClassifier

    qdrant_url = (settings.QDRANT_URL or "").strip()
    qdrant = QdrantClient(
        url=qdrant_url,
        api_key=settings.QDRANT_API_KEY,
        timeout=60,
    )
    vector_store = VectorStore(qdrant, settings.QDRANT_COLLECTION)

    # Idempotent: creates collection if absent so a cold-start before first
    # ingest run doesn't crash every search with "collection not found".
    try:
        vector_store.ensure_collection()
    except UnexpectedResponse as exc:
        fallback_url = _normalize_qdrant_url(qdrant_url)
        if _is_qdrant_not_found(exc) and fallback_url and fallback_url != qdrant_url:
            logger.warning(
                "Qdrant URL returned 404, retrying with normalized root URL | original=%s normalized=%s",
                qdrant_url,
                fallback_url,
            )
            # Release the client bound to the bad URL before replacing it.
            qdrant.close()
            qdrant = QdrantClient(
                url=fallback_url,
                api_key=settings.QDRANT_API_KEY,
                timeout=60,
            )
            vector_store = VectorStore(qdrant, settings.QDRANT_COLLECTION)
            vector_store.ensure_collection()
        else:
            raise

    # Issue 7: shared TPM bucket tracks token consumption across the current 60s
    # window. Injected into GroqClient so it can downgrade 70B → 8B automatically
    # when the bucket is above 12,000 tokens, preventing hard rate-limit failures.
    tpm_bucket = TpmBucket()
    llm_client = get_llm_client(settings, tpm_bucket=tpm_bucket)

    # Expose llm_client on app state so chat.py can use it for follow-up
    # question generation without re-constructing the client per request.
    app.state.llm_client = llm_client

    app.state.pipeline = build_pipeline({
        "classifier": GuardClassifier(),
        "cache": app.state.semantic_cache,
        "embedder": embedder,
        "gemini": gemini_client,
        "llm": llm_client,
        "vector_store": vector_store,
        "reranker": reranker,
        "db_path": settings.DB_PATH,
        "github_log": github_log,
    })
    app.state.settings = settings
    app.state.qdrant = qdrant

    keepalive_stop = asyncio.Event()
    keepalive_task = asyncio.create_task(
        _qdrant_keepalive_loop(
            qdrant=qdrant,
            interval_seconds=settings.QDRANT_KEEPALIVE_SECONDS,
            stop_event=keepalive_stop,
        )
    )
    app.state.qdrant_keepalive_stop = keepalive_stop
    app.state.qdrant_keepalive_task = keepalive_task

    logger.info("Startup complete")
    try:
        yield
    finally:
        # Run teardown even if an exception or cancellation is delivered at
        # the yield, so the Qdrant client and MLflow run are always released.
        logger.info("Shutting down")
        app.state.qdrant_keepalive_stop.set()
        try:
            await asyncio.wait_for(app.state.qdrant_keepalive_task, timeout=2)
        except asyncio.TimeoutError:
            # Same pre-3.11 aliasing concern as in the keepalive loop.
            app.state.qdrant_keepalive_task.cancel()
        except Exception as exc:
            # Best-effort: a keepalive failure must not block shutdown.
            logger.warning("Qdrant keepalive task ended with an error: %s", exc)

        app.state.semantic_cache = None
        app.state.qdrant.close()

        # Only attempt to end an MLflow run when DagsHub tracking was enabled at startup.
        if settings.DAGSHUB_TOKEN:
            import mlflow

            if mlflow.active_run():
                mlflow.end_run()


def create_app() -> FastAPI:
    """Construct and configure the FastAPI application.

    Interactive docs (``/docs``, ``/redoc``, ``/openapi.json``) and the
    ``http://localhost:3000`` CORS origin are enabled only when
    ``settings.ENVIRONMENT`` is one of local / staging / test.
    """
    settings = get_settings()
    dev_mode = settings.ENVIRONMENT in _DEV_ENVIRONMENTS

    # The docs URLs must be passed to the constructor: FastAPI registers the
    # documentation routes during __init__, so assigning app.docs_url (etc.)
    # afterwards has no effect.
    app = FastAPI(
        title="PersonaBot API",
        lifespan=lifespan,
        docs_url="/docs" if dev_mode else None,
        redoc_url="/redoc" if dev_mode else None,
        openapi_url="/openapi.json" if dev_mode else None,
    )
    app.state.limiter = limiter

    origins = [settings.ALLOWED_ORIGIN]
    if dev_mode:
        origins.append("http://localhost:3000")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["POST", "GET", "OPTIONS"],
        allow_headers=["Content-Type", "Authorization"],
    )

    @app.exception_handler(AppError)
    async def app_error_handler(request: Request, exc: AppError) -> JSONResponse:
        # Domain errors are reported to the client as HTTP 400 with their message.
        logger.error("AppError: %s", exc.message, extra={"context": exc.context})
        return JSONResponse(status_code=400, content={"error": exc.message})

    @app.exception_handler(Exception)
    async def global_error_handler(request: Request, exc: Exception) -> JSONResponse:
        # Never leak internals: log the traceback, return a generic 500 body.
        logger.error("Unhandled exception", exc_info=exc)
        return JSONResponse(status_code=500, content={"error": "Internal Server Error"})

    app.add_exception_handler(RateLimitExceeded, custom_rate_limit_handler)

    app.include_router(health_router, tags=["Health"])
    app.include_router(chat_router, prefix="/chat", tags=["Chat"])
    app.include_router(transcribe_router, prefix="/transcribe", tags=["Transcribe"])
    app.include_router(tts_router, prefix="/tts", tags=["TTS"])
    app.include_router(feedback_router, prefix="/chat", tags=["Feedback"])
    app.include_router(admin_router, prefix="/admin", tags=["Admin"])

    return app


app = create_app()