"""PersonaBot API: application factory, service wiring, and lifespan management."""
| import asyncio | |
| from contextlib import asynccontextmanager | |
| import os | |
| import sqlite3 | |
| from urllib.parse import urlsplit, urlunsplit | |
| from fastapi import FastAPI, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from qdrant_client.http.exceptions import UnexpectedResponse | |
| from slowapi.errors import RateLimitExceeded | |
| from app.api.admin import router as admin_router | |
| from app.api.chat import router as chat_router | |
| from app.api.feedback import router as feedback_router | |
| from app.api.health import router as health_router | |
| from app.api.tts import router as tts_router | |
| from app.api.transcribe import router as transcribe_router | |
| from app.core.config import get_settings | |
| from app.core.exceptions import AppError | |
| from app.core.logging import get_logger | |
| from app.pipeline.graph import build_pipeline | |
| from app.security.rate_limiter import limiter, custom_rate_limit_handler | |
| from app.services.embedder import Embedder | |
| from app.services.gemini_client import GeminiClient | |
| from app.services.github_log import GithubLog | |
| from app.services.llm_client import get_llm_client, TpmBucket | |
| from app.services.reranker import Reranker | |
| from app.services.semantic_cache import SemanticCache | |
| from app.services.transcriber import GroqTranscriber | |
| from app.services.tts_client import TTSClient | |
| from app.services.conversation_store import ConversationStore | |
| from qdrant_client import QdrantClient | |
# Module-wide logger, named after this module per the app.core.logging convention.
logger = get_logger(__name__)
def _is_qdrant_not_found(exc: Exception) -> bool:
    """Detect whether an exception represents an HTTP 404 response from Qdrant."""
    # Structured path: the client exception carries an explicit status code.
    if isinstance(exc, UnexpectedResponse) and getattr(exc, "status_code", None) == 404:
        return True
    # Fallback path: look for the typical 404 wording in the stringified error.
    text = str(exc)
    return "404" in text and "page not found" in text.lower()
| def _normalize_qdrant_url(url: str) -> str: | |
| """ | |
| Normalize QDRANT_URL to an API base URL. | |
| If the configured URL includes a non-root path (for example, a dashboard | |
| URL), strip the path and keep scheme + host(+port) only. | |
| """ | |
| raw = (url or "").strip().rstrip("/") | |
| if not raw: | |
| return raw | |
| if "://" not in raw: | |
| scheme = "http" if raw.startswith(("localhost", "127.0.0.1")) else "https" | |
| raw = f"{scheme}://{raw}" | |
| parsed = urlsplit(raw) | |
| if not parsed.netloc: | |
| return raw | |
| if parsed.path and parsed.path != "/": | |
| return urlunsplit((parsed.scheme, parsed.netloc, "", "", "")).rstrip("/") | |
| return raw | |
| def _sqlite_row_count(db_path: str) -> int: | |
| """Return the current interactions row count, or 0 if the table doesn't exist.""" | |
| try: | |
| with sqlite3.connect(db_path) as conn: | |
| return conn.execute("SELECT COUNT(*) FROM interactions").fetchone()[0] | |
| except sqlite3.OperationalError: | |
| return 0 | |
| except Exception: | |
| return 0 | |
async def _qdrant_keepalive_loop(
    qdrant: QdrantClient,
    interval_seconds: int,
    stop_event: asyncio.Event,
) -> None:
    """
    Ping Qdrant once per `interval_seconds` until `stop_event` is set.

    The qdrant-client API is synchronous, so each ping runs on a worker
    thread via asyncio.to_thread to keep the event loop responsive. A
    non-positive interval disables the loop entirely.
    """
    if interval_seconds <= 0:
        return
    while not stop_event.is_set():
        try:
            # Sleep one interval, waking early if shutdown is requested.
            await asyncio.wait_for(stop_event.wait(), timeout=interval_seconds)
        except TimeoutError:
            # Interval elapsed with no stop signal — time to ping.
            pass
        else:
            # stop_event fired during the wait; exit immediately.
            break
        try:
            await asyncio.to_thread(qdrant.get_collections)
            logger.info("Qdrant keepalive ping succeeded")
        except Exception as exc:
            logger.warning("Qdrant keepalive ping failed: %s", exc)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan context: construct every service at startup, yield while
    the app serves traffic, then tear everything down on shutdown.

    Fix: the `asynccontextmanager` import existed but the decorator was
    missing. FastAPI's `lifespan=` parameter expects a callable returning an
    async context manager; a bare async generator function cannot be used
    with `async with` and fails at startup.
    """
    settings = get_settings()
    logger.info("Starting PersonaBot API | env=%s", settings.ENVIRONMENT)
    # Durable GitHub interaction log — survives HF Space restarts.
    # When PERSONABOT_WRITE_TOKEN is not set (local dev), GithubLog.enabled=False
    # and all append calls are silent no-ops.
    github_log = GithubLog(
        write_token=settings.PERSONABOT_WRITE_TOKEN or "",
        repo=settings.PERSONABOT_REPO,
    )
    app.state.github_log = github_log
    # Attach the in-memory semantic cache. No external service required.
    app.state.semantic_cache = SemanticCache(
        max_size=settings.SEMANTIC_CACHE_SIZE,
        ttl_seconds=settings.SEMANTIC_CACHE_TTL_SECONDS,
        similarity_threshold=settings.SEMANTIC_CACHE_SIMILARITY_THRESHOLD,
    )
    app.state.conversation_store = ConversationStore(settings.DB_PATH, github_log=github_log)
    # Reconstruct SQLite conversation history from the durable GitHub log
    # after an ephemeral HF Space restart. Only triggers when SQLite is empty
    # (<10 rows) so a healthy Space with accumulated data is never overwritten.
    if github_log.enabled and _sqlite_row_count(settings.DB_PATH) < 10:
        logger.info("SQLite appears empty — attempting reconstruction from durable log.")
        recent = await github_log.load_recent(500)
        if recent:
            app.state.conversation_store.populate_from_records(recent)
    # DagsHub/MLflow experiment tracking — optional, only active when token is set.
    # In local or test environments, MLflow is a no-op.
    if settings.DAGSHUB_TOKEN:
        import dagshub
        dagshub.init(
            repo_owner=settings.DAGSHUB_REPO.split("/")[0],
            repo_name=settings.DAGSHUB_REPO.split("/")[1],
            mlflow=True,
            dvc=False,
        )
        logger.info("DagsHub MLflow tracking enabled | repo=%s", settings.DAGSHUB_REPO)
    embedder = Embedder(remote_url=settings.EMBEDDER_URL, environment=settings.ENVIRONMENT)
    reranker = Reranker(remote_url=settings.RERANKER_URL, environment=settings.ENVIRONMENT)
    gemini_client = GeminiClient(
        api_key=settings.GEMINI_API_KEY or "",
        model=settings.GEMINI_MODEL,
        context_path=settings.GEMINI_CONTEXT_PATH,
    )
    app.state.gemini_client = gemini_client
    app.state.transcriber = GroqTranscriber(
        api_key=settings.GROQ_API_KEY or "",
        model=settings.GROQ_TRANSCRIBE_MODEL,
        timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
    )
    # NOTE(review): TTS reuses the transcription timeout setting — confirm the
    # TTS Space should share TRANSCRIBE_TIMEOUT_SECONDS rather than its own.
    app.state.tts_client = TTSClient(
        tts_space_url=settings.TTS_SPACE_URL,
        timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
    )
    from app.services.vector_store import VectorStore
    from app.security.guard_classifier import GuardClassifier
    qdrant_url = (settings.QDRANT_URL or "").strip()
    qdrant = QdrantClient(
        url=qdrant_url,
        api_key=settings.QDRANT_API_KEY,
        timeout=60,
    )
    vector_store = VectorStore(qdrant, settings.QDRANT_COLLECTION)
    # Idempotent: creates collection if absent so a cold-start before first
    # ingest run doesn't crash every search with "collection not found".
    try:
        vector_store.ensure_collection()
    except UnexpectedResponse as exc:
        # A 404 usually means the configured URL points at a dashboard path;
        # retry once against the normalized root URL before giving up.
        fallback_url = _normalize_qdrant_url(qdrant_url)
        if _is_qdrant_not_found(exc) and fallback_url and fallback_url != qdrant_url:
            logger.warning(
                "Qdrant URL returned 404, retrying with normalized root URL | original=%s normalized=%s",
                qdrant_url,
                fallback_url,
            )
            qdrant.close()
            qdrant = QdrantClient(
                url=fallback_url,
                api_key=settings.QDRANT_API_KEY,
                timeout=60,
            )
            vector_store = VectorStore(qdrant, settings.QDRANT_COLLECTION)
            vector_store.ensure_collection()
        else:
            raise
    # Shared TPM bucket tracks token consumption across the current 60s
    # window. Injected into the LLM client so it can downgrade 70B → 8B
    # automatically above 12,000 tokens, preventing hard rate-limit failures.
    tpm_bucket = TpmBucket()
    llm_client = get_llm_client(settings, tpm_bucket=tpm_bucket)
    # Expose llm_client on app state so chat.py can use it for follow-up
    # question generation without re-constructing the client per request.
    app.state.llm_client = llm_client
    app.state.pipeline = build_pipeline({
        "classifier": GuardClassifier(),
        "cache": app.state.semantic_cache,
        "embedder": embedder,
        "gemini": gemini_client,
        "llm": llm_client,
        "vector_store": vector_store,
        "reranker": reranker,
        "db_path": settings.DB_PATH,
        "github_log": github_log,
    })
    app.state.settings = settings
    app.state.qdrant = qdrant
    # Background keepalive so the Qdrant deployment keeps an active connection.
    keepalive_stop = asyncio.Event()
    keepalive_task = asyncio.create_task(
        _qdrant_keepalive_loop(
            qdrant=qdrant,
            interval_seconds=settings.QDRANT_KEEPALIVE_SECONDS,
            stop_event=keepalive_stop,
        )
    )
    app.state.qdrant_keepalive_stop = keepalive_stop
    app.state.qdrant_keepalive_task = keepalive_task
    logger.info("Startup complete")
    yield
    logger.info("Shutting down")
    # Stop the keepalive loop; cancel it if it doesn't exit within 2s.
    app.state.qdrant_keepalive_stop.set()
    try:
        await asyncio.wait_for(app.state.qdrant_keepalive_task, timeout=2)
    except TimeoutError:
        app.state.qdrant_keepalive_task.cancel()
    except Exception:
        pass
    app.state.semantic_cache = None
    app.state.qdrant.close()
    # Only attempt to end an MLflow run when DagsHub tracking was enabled at startup.
    if settings.DAGSHUB_TOKEN:
        import mlflow
        if mlflow.active_run():
            mlflow.end_run()
def create_app() -> FastAPI:
    """
    Build and configure the FastAPI application.

    Interactive docs are exposed only in non-production environments; CORS is
    restricted to the configured origin (plus localhost:3000 outside prod).
    """
    settings = get_settings()
    # Fix: FastAPI wires the /docs, /redoc and /openapi.json routes during
    # __init__ from these constructor arguments. The previous code passed
    # None here and then assigned app.docs_url afterwards, which never
    # registers the routes — the URLs must be decided before construction.
    expose_docs = settings.ENVIRONMENT in ("local", "staging", "test")
    app = FastAPI(
        title="PersonaBot API",
        lifespan=lifespan,
        docs_url="/docs" if expose_docs else None,
        redoc_url="/redoc" if expose_docs else None,
        openapi_url="/openapi.json" if expose_docs else None,
    )
    app.state.limiter = limiter
    origins = [settings.ALLOWED_ORIGIN]
    if expose_docs:
        origins.append("http://localhost:3000")
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["POST", "GET", "OPTIONS"],
        allow_headers=["Content-Type", "Authorization"],
    )

    async def app_error_handler(request: Request, exc: AppError) -> JSONResponse:
        """Map domain AppError to a 400 with its message, logging the context."""
        logger.error("AppError: %s", exc.message, extra={"context": exc.context})
        return JSONResponse(status_code=400, content={"error": exc.message})

    async def global_error_handler(request: Request, exc: Exception) -> JSONResponse:
        """Last-resort handler: log the traceback, return an opaque 500."""
        logger.error("Unhandled exception", exc_info=exc)
        return JSONResponse(status_code=500, content={"error": "Internal Server Error"})

    # Fix: these two handlers were defined but never registered, so AppError
    # and unhandled exceptions fell through to Starlette's defaults.
    app.add_exception_handler(AppError, app_error_handler)
    app.add_exception_handler(Exception, global_error_handler)
    app.add_exception_handler(RateLimitExceeded, custom_rate_limit_handler)
    app.include_router(health_router, tags=["Health"])
    app.include_router(chat_router, prefix="/chat", tags=["Chat"])
    app.include_router(transcribe_router, prefix="/transcribe", tags=["Transcribe"])
    app.include_router(tts_router, prefix="/tts", tags=["TTS"])
    app.include_router(feedback_router, prefix="/chat", tags=["Feedback"])
    app.include_router(admin_router, prefix="/admin", tags=["Admin"])
    return app
# Module-level application instance, the target for the ASGI server
# (e.g. `uvicorn app.main:app`).
app = create_app()