# NOTE: stray "Spaces: Running" status text from the hosting page had crept
# into this file; it was not part of the Python source.
| from __future__ import annotations | |
| import json | |
| import importlib.util | |
| import logging | |
| from pathlib import Path | |
| from fastapi import FastAPI, File, Request, UploadFile, BackgroundTasks | |
| from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.exceptions import RequestValidationError | |
| from app.config import Settings, load_settings | |
| from app.models.schemas import RAGResponse, RetrievedContext | |
| from app.models.types import Language, compute_doc_id | |
| from app.pipelines.ingest import ingest_pdf | |
| from app.pipelines.query import answer_question | |
| from app.services.rate_limiter import IpRateLimiter | |
| from app.services.vector_store import VectorStore | |
# Configure root logging once at import time so every module that calls
# logging.getLogger shares the same level and format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Maps the client-supplied "language" field to the internal Language enum;
# "auto" maps to None, which downstream code treats as "detect automatically".
LANG_MAP = {"auto": None, "en": Language.EN, "id": Language.ID}

# Directory holding the bundled front-end (index.html and static assets).
STATIC_DIR = Path(__file__).resolve().parent.parent / "static"

# python-multipart is required by FastAPI to parse multipart form uploads;
# the upload route is only defined when it is importable.
HAS_MULTIPART = importlib.util.find_spec("multipart") is not None
| def _get_runtime_dependencies(app: FastAPI) -> tuple[Settings, VectorStore]: | |
| settings = getattr(app.state, "settings", None) | |
| vector_store = getattr(app.state, "vector_store", None) | |
| if settings is None: | |
| settings = load_settings() | |
| app.state.settings = settings | |
| if vector_store is None: | |
| vector_store = VectorStore(settings) | |
| vector_store.ensure_index() | |
| app.state.vector_store = vector_store | |
| return settings, vector_store | |
| def _get_rate_limiter(app: FastAPI) -> IpRateLimiter: | |
| limiter = getattr(app.state, "rate_limiter", None) | |
| if limiter is not None: | |
| return limiter | |
| settings = getattr(app.state, "settings", None) | |
| if settings is None: | |
| settings = load_settings() | |
| app.state.settings = settings | |
| limiter = IpRateLimiter( | |
| max_requests=settings.rate_limit_requests_per_hour, | |
| window_seconds=settings.rate_limit_window_seconds, | |
| ) | |
| app.state.rate_limiter = limiter | |
| return limiter | |
| def _resolve_client_ip(request: Request) -> str: | |
| forwarded_for = request.headers.get("x-forwarded-for", "") | |
| if forwarded_for: | |
| first_ip = forwarded_for.split(",", 1)[0].strip() | |
| if first_ip: | |
| return first_ip | |
| real_ip = request.headers.get("x-real-ip", "").strip() | |
| if real_ip: | |
| return real_ip | |
| client_host = request.client.host if request.client else "" | |
| return client_host.strip() or "unknown" | |
| def _build_evidence_payload(evidence: list[RetrievedContext]) -> list[dict[str, object]]: | |
| return [ | |
| { | |
| "id": ctx.evidence_id, | |
| "source": ctx.source_filename, | |
| "page": ctx.page, | |
| "snippet": ctx.snippet, | |
| "score": round(ctx.score, 3), | |
| "rerank_score": round(ctx.rerank_score, 3) | |
| if ctx.rerank_score is not None | |
| else None, | |
| } | |
| for ctx in evidence | |
| ] | |
| def _build_citations_payload(evidence: list[RetrievedContext]) -> list[dict[str, object]]: | |
| seen: set[tuple[str, int]] = set() | |
| citations: list[dict[str, object]] = [] | |
| for ctx in evidence: | |
| key = (ctx.source_filename, ctx.page) | |
| if key in seen: | |
| continue | |
| seen.add(key) | |
| citations.append( | |
| { | |
| "source": ctx.source_filename, | |
| "page": ctx.page, | |
| "score": round(ctx.score, 3), | |
| } | |
| ) | |
| return citations | |
| def _stats_value(stats: object, key: str, default: object = None) -> object: | |
| if isinstance(stats, dict): | |
| return stats.get(key, default) | |
| return getattr(stats, key, default) | |
def create_app(
    settings: Settings | None = None,
    vector_store: VectorStore | None = None,
    rate_limiter: IpRateLimiter | None = None,
) -> FastAPI:
    """Build the Gapura RAG FastAPI application.

    Any dependency passed as ``None`` is created lazily on first use (see
    ``_get_runtime_dependencies`` / ``_get_rate_limiter``), keeping startup
    fast and making the factory easy to drive from tests with fakes.
    """
    app = FastAPI(title="Gapura RAG API")

    # FIX: the handler was defined but never registered with the app, so
    # clients got FastAPI's default 422 body and nothing was logged.
    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError):
        # We manually log the detail because the user only sees the 422 in their terminal
        logger.error(f"Validation error for {request.url}: {exc.errors()}")
        return JSONResponse(
            status_code=422,
            content={"detail": exc.errors(), "body": str(exc.body)},
        )

    # Pre-seeded dependencies (may be None; resolved lazily per request).
    app.state.settings = settings
    app.state.vector_store = vector_store
    app.state.rate_limiter = rate_limiter

    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
    # NOTE(review): this mounts the package's *parent* directory, which serves
    # the application source tree under /assets — confirm this is intentional
    # and not an information-disclosure risk.
    app.mount(
        "/assets",
        StaticFiles(directory=str(Path(__file__).resolve().parent.parent)),
        name="assets",
    )

    # FIX: route decorators were missing, so none of the handlers below were
    # reachable. Paths restored to conventional values — TODO(review): confirm
    # they match what static/index.html calls.
    @app.get("/", response_class=HTMLResponse)
    async def index() -> HTMLResponse:
        """Serve the bundled single-page front-end."""
        html_path = STATIC_DIR / "index.html"
        return HTMLResponse(content=html_path.read_text(encoding="utf-8"))

    if HAS_MULTIPART:
        # (Removed a redundant local ``from fastapi import BackgroundTasks``;
        # it is already imported at module level.)

        def _bg_ingest_pdf(
            file_bytes: bytes,
            filename: str,
            settings: Settings,
            vector_store: VectorStore,
            replace_existing: bool = False,
        ) -> None:
            """Run PDF ingestion outside the request cycle, logging the outcome."""
            try:
                logger.info("Starting background ingestion for %s", filename)
                result = ingest_pdf(
                    file_bytes,
                    filename,
                    settings,
                    vector_store,
                    replace_existing=replace_existing,
                )
                logger.info(
                    "Background ingestion completed for %s: %d pages, %d chunks",
                    filename,
                    result.total_pages,
                    result.total_chunks,
                )
            except Exception as exc:
                # Best-effort: a failed ingest must not crash the worker.
                logger.error("Background ingestion failed for %s: %s", filename, exc)

        @app.post("/api/upload")
        async def upload_pdf(
            request: Request,
            background_tasks: BackgroundTasks,
            file: UploadFile = File(...),
        ):
            """Validate an uploaded PDF and schedule background ingestion.

            Returns immediately with placeholder page/chunk counts of -1;
            a duplicate document (same content hash) is skipped unless
            ``?replace=true`` is supplied.
            """
            try:
                settings, vector_store = _get_runtime_dependencies(request.app)
                if not file.filename or not file.filename.lower().endswith(".pdf"):
                    return JSONResponse({"error": "Only PDF files are accepted"}, status_code=400)
                file_bytes = await file.read()
                size_mb = len(file_bytes) / (1024 * 1024)
                doc_id = compute_doc_id(file_bytes)
                replace_existing = (
                    request.query_params.get("replace", "").strip().lower() == "true"
                )
                if size_mb > settings.max_pdf_size_mb:
                    return JSONResponse(
                        {"error": f"File exceeds {settings.max_pdf_size_mb}MB limit"},
                        status_code=400,
                    )
                # Check if document already exists (fast)
                if vector_store.doc_exists(doc_id) and not replace_existing:
                    return {
                        "doc_id": str(doc_id),
                        "filename": file.filename,
                        "pages": 0,
                        "chunks": 0,
                        "skipped": True,
                    }
                # Start background processing
                background_tasks.add_task(
                    _bg_ingest_pdf,
                    file_bytes,
                    file.filename,
                    settings,
                    vector_store,
                    replace_existing,
                )
                # Return immediate response. Page/chunk counts won't be known yet,
                # so we return a placeholder or 0.
                return {
                    "doc_id": str(doc_id),
                    "filename": file.filename,
                    "pages": -1,  # UI special value for "Processing..."
                    "chunks": -1,
                    "skipped": False,
                    "replace_existing": replace_existing,
                }
            except Exception as exc:
                logger.exception("Upload error for %s", file.filename)
                return JSONResponse({"error": str(exc)}, status_code=500)
    else:

        @app.post("/api/upload")
        async def upload_pdf_unavailable():
            """Stand-in upload route when python-multipart is not installed."""
            return JSONResponse(
                {"error": "python-multipart is not installed on this deployment."},
                status_code=500,
            )

    @app.post("/api/chat")
    async def chat(request: Request):
        """Answer a question over the indexed corpus, streamed as SSE."""
        body = await request.json()
        question = body.get("question", "").strip()
        lang = body.get("language", "auto")
        history = body.get("history", [])
        sources = body.get("sources")
        if not question:
            return JSONResponse({"error": "Question cannot be empty"}, status_code=400)
        try:
            limiter = _get_rate_limiter(request.app)
        except Exception as exc:
            logger.exception("Rate limiter bootstrap failed during chat")
            return JSONResponse({"error": str(exc)}, status_code=500)
        settings = getattr(request.app.state, "settings", None)
        # The fallback of 5 is only for the human-readable 429 message when
        # settings are not cached on app.state; it does not affect limiting.
        max_requests = (
            settings.rate_limit_requests_per_hour
            if settings is not None
            else 5
        )
        client_ip = _resolve_client_ip(request)
        limit_decision = limiter.check(client_ip)
        if not limit_decision.allowed:
            retry_after = limit_decision.retry_after_seconds
            return JSONResponse(
                {
                    "error": (
                        "Rate limit exceeded. "
                        f"You can ask up to {max_requests} questions per hour from the same IP address."
                    ),
                    "retry_after_seconds": retry_after,
                },
                status_code=429,
                headers={"Retry-After": str(retry_after)},
            )
        try:
            settings, vector_store = _get_runtime_dependencies(request.app)
        except Exception as exc:
            logger.exception("Runtime bootstrap failed during chat")
            return JSONResponse({"error": str(exc)}, status_code=500)
        lang_override = LANG_MAP.get(lang)

        def event_stream():
            # Synchronous generator: answer_question is blocking, so the whole
            # answer is produced first and then emitted as SSE events
            # ("token", then "done"; "error" on failure).
            try:
                response: RAGResponse = answer_question(
                    question=question,
                    settings=settings,
                    vector_store=vector_store,
                    language_override=lang_override,
                    history=history or None,
                    sources=sources,
                )
                if response.answer:
                    yield (
                        "data: "
                        + json.dumps({"type": "token", "content": response.answer})
                        + "\n\n"
                    )
                yield (
                    "data: "
                    + json.dumps(
                        {
                            "type": "done",
                            "answer": response.answer,
                            "language": response.detected_language.value,
                            "grounding_status": response.grounding_status.value,
                            "supplement_used": response.supplement_used,
                            "evidence": _build_evidence_payload(response.evidence),
                            "citations": _build_citations_payload(
                                response.citations or response.evidence
                            ),
                        }
                    )
                    + "\n\n"
                )
            except Exception as exc:
                logger.exception("Chat error")
                yield f"data: {json.dumps({'type': 'error', 'content': str(exc)})}\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    @app.get("/api/stats")
    async def stats(request: Request):
        """Expose index/runtime configuration for diagnostics."""
        try:
            settings, vector_store = _get_runtime_dependencies(request.app)
            raw = vector_store.get_stats()
            binding = vector_store.get_index_binding()
            return {
                "index": vector_store.index_name,
                "configured_index": binding["configured_index"],
                "active_index": binding["active_index"],
                "total_vectors": _stats_value(raw, "total_vector_count", 0),
                "embedding_model": settings.embedding_model,
                "embedding_dim": binding["embedding_dim"],
                "index_dimension": binding["index_dimension"],
                "pinecone_metric": binding["metric"],
                "llm_model": settings.llm_model,
            }
        except Exception as exc:
            return JSONResponse({"error": str(exc)}, status_code=500)

    return app
| app = create_app() | |