Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from dotenv import load_dotenv | |
| load_dotenv(Path(__file__).resolve().parent.parent / ".env") | |
| import asyncio | |
| import json | |
| import os | |
| import re | |
| from typing import Literal | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from .ingestion import read_upload | |
| from .parser import parse_file | |
| from .rag import RAGEngine | |
app = FastAPI(version="1.0.0", title="Document Ingestion API")

# Permissive CORS policy: any origin may call any method with any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# One process-wide engine that every handler in this module reads and mutates.
# The embedding backend name comes from the EMBED_PROVIDER environment
# variable (which the load_dotenv() call at import time may populate),
# falling back to "gemini".
rag_engine = RAGEngine(
    embed_provider=os.getenv("EMBED_PROVIDER", "gemini"),
)
class ChatMessage(BaseModel):
    """One earlier conversation turn, sent as part of a query's history."""

    # Speaker of the turn. Free-form string, not validated here
    # (presumably "user" / "assistant" — confirm against the frontend).
    role: str
    # Text of the turn.
    content: str
class QueryRequest(BaseModel):
    """Request body shared by the blocking and streaming query handlers."""

    # The user's question; handlers reject it when blank after stripping.
    query: str
    # Earlier conversation turns; handlers convert them to plain
    # {"role": ..., "content": ...} dicts before calling the engine.
    # NOTE: pydantic is expected to copy this default per instance — confirm
    # for the pinned pydantic version.
    history: list[ChatMessage] = []
    # Assigned to rag_engine.top_k before answering (presumably the number of
    # chunks retrieved). Not range-checked here.
    top_k: int = 3
    # Forwarded unchanged to the engine as answer_mode.
    answer_mode: Literal["balanced", "strict_grounded"] = "balanced"
class IngestTextRequest(BaseModel):
    """Request body for ingesting raw text instead of an uploaded file."""

    # Raw document text; the handler rejects it when blank after stripping.
    text: str
    # Chunking parameters copied onto rag_engine before ingestion. Their units
    # (characters vs. words/tokens) are defined by RAGEngine — TODO confirm.
    chunk_size: int = 180
    chunk_overlap: int = 40
| def _compute_recommendations(text: str, num_chunks: int) -> dict: | |
| sentences = [s.strip() for s in re.split(r"[.!?\n]+", text) if s.strip()] | |
| avg_len = sum(len(s) for s in sentences) / max(len(sentences), 1) | |
| text_len = len(text) | |
| if text_len > 5_000_000: | |
| rec_cs = max(500, min(1000, int(avg_len * 8 / 10) * 10)) | |
| elif text_len > 1_000_000: | |
| rec_cs = max(300, min(800, int(avg_len * 6 / 10) * 10)) | |
| else: | |
| rec_cs = max(50, min(1000, int(avg_len * 4 / 10) * 10)) | |
| rec_co = max(0, min(200, int(rec_cs * 0.2 / 5) * 5)) | |
| if num_chunks <= 20: | |
| rec_tk = 3 | |
| elif num_chunks <= 100: | |
| rec_tk = 5 | |
| elif num_chunks <= 500: | |
| rec_tk = 7 | |
| else: | |
| rec_tk = 10 | |
| return {"chunk_size": rec_cs, "chunk_overlap": rec_co, "top_k": rec_tk} | |
async def upload_file(
    file: UploadFile = File(...),
    chunk_size: int = Form(180),
    chunk_overlap: int = Form(40),
):
    """Parse an uploaded file, chunk it, and start embedding in the background.

    Args:
        file: The uploaded document (multipart form field).
        chunk_size: Chunk size assigned to ``rag_engine`` before ingestion;
            must be positive.
        chunk_overlap: Overlap assigned to ``rag_engine`` before ingestion;
            must satisfy ``0 <= chunk_overlap < chunk_size``.

    Returns:
        The extracted text, file metadata (including the chunking parameters
        used), the number of chunks produced, and recommended
        chunk_size / chunk_overlap / top_k values.

    Raises:
        HTTPException: 400 for invalid chunking parameters or when the parser
            raises ``ValueError``; 500 for any other parsing failure.
    """
    import logging  # function-local so this handler needs no new module import

    # Reject settings where the overlap is not smaller than the chunk, or the
    # chunk is empty. NOTE(review): assumes RAGEngine advances by
    # chunk_size - chunk_overlap, which would then never move forward.
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise HTTPException(
            status_code=400,
            detail="chunk_size must be positive and 0 <= chunk_overlap < chunk_size.",
        )

    content, metadata = await read_upload(file)
    try:
        # parse_file is synchronous; run it in a worker thread so the event
        # loop keeps serving other requests (e.g. status polling) while a large
        # file is parsed. Assumes parse_file is safe to call off the loop
        # thread — it only receives its arguments here.
        text = await asyncio.to_thread(parse_file, metadata.filename, content)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # The client only gets a generic message, so keep the traceback in the
        # server log for diagnosis.
        logging.getLogger(__name__).exception(
            "Failed to parse uploaded file %r", metadata.filename
        )
        raise HTTPException(status_code=500, detail="Failed to process the file.") from e

    # No await between configuring the shared engine and starting ingestion,
    # so another handler cannot interleave here.
    rag_engine.chunk_size = chunk_size
    rag_engine.chunk_overlap = chunk_overlap
    num_chunks = rag_engine.start_ingest(text)

    # Fire-and-forget: embedding runs on the loop's default thread-pool
    # executor and the response is sent without waiting for it (progress is
    # presumably reported via rag_engine.ingest_progress).
    asyncio.get_running_loop().run_in_executor(None, rag_engine._do_embed)

    recs = _compute_recommendations(text, num_chunks)
    return {
        "text": text,
        "metadata": {
            "filename": metadata.filename,
            "file_type": metadata.file_type,
            "size_bytes": metadata.size_bytes,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
        },
        "num_chunks": num_chunks,
        "recommended_chunk_size": recs["chunk_size"],
        "recommended_chunk_overlap": recs["chunk_overlap"],
        "recommended_top_k": recs["top_k"],
    }
async def ingest_text(body: IngestTextRequest):
    """Chunk raw text from the request body and start embedding it in the background.

    Returns:
        ``{"num_chunks": int}``, the count returned by ``rag_engine.start_ingest``.

    Raises:
        HTTPException: 400 if the text is blank, or if the chunking parameters
            are invalid (``chunk_size <= 0`` or not
            ``0 <= chunk_overlap < chunk_size``).
    """
    if not body.text.strip():
        raise HTTPException(status_code=400, detail="Text is empty.")
    # Reject settings where the overlap is not smaller than the chunk, or the
    # chunk is empty. NOTE(review): assumes RAGEngine advances by
    # chunk_size - chunk_overlap, which would then never move forward.
    if body.chunk_size <= 0 or not 0 <= body.chunk_overlap < body.chunk_size:
        raise HTTPException(
            status_code=400,
            detail="chunk_size must be positive and 0 <= chunk_overlap < chunk_size.",
        )

    rag_engine.chunk_size = body.chunk_size
    rag_engine.chunk_overlap = body.chunk_overlap
    num_chunks = rag_engine.start_ingest(body.text)

    # Fire-and-forget: embedding runs on the loop's default thread-pool
    # executor while the response is sent.
    asyncio.get_running_loop().run_in_executor(None, rag_engine._do_embed)
    return {"num_chunks": num_chunks}
async def ingest_status():
    """Return ``rag_engine.ingest_progress`` unchanged (the engine's ingestion progress)."""
    progress = rag_engine.ingest_progress
    return progress
async def config():
    """Describe the active configuration.

    Returns the engine's embedding provider and the retrieval backend name,
    which is the fixed string "knowledge_graph".
    """
    return dict(
        embed_provider=rag_engine.embed_provider,
        retrieval_backend="knowledge_graph",
    )
async def vector_store_status():
    """Return the report from ``rag_engine.vector_store_status()`` unchanged."""
    status = rag_engine.vector_store_status()
    return status
async def graph_store_status():
    """Return the engine's store status report.

    NOTE(review): this calls ``rag_engine.vector_store_status()``, exactly like
    ``vector_store_status``. The name (and ``retrieval_backend:
    "knowledge_graph"`` in ``config``) suggests a separate graph-store report.
    Confirm whether RAGEngine exposes a dedicated graph-store status method or
    whether this alias is intentional.
    """
    report = rag_engine.vector_store_status()
    return report
async def query_document(body: QueryRequest):
    """Answer ``body.query`` against the ingested document in a single response.

    Returns:
        Whatever ``rag_engine.answer`` returns, passed through unchanged.

    Raises:
        HTTPException: 400 if nothing has been ingested or the query is blank;
            409 while the document is still being embedded.
    """
    # Guard clauses, checked in this order: document present, embedding
    # finished, query non-blank.
    if not rag_engine.chunks:
        raise HTTPException(status_code=400, detail="No document has been ingested yet. Upload a file first.")
    if not rag_engine.is_ready:
        raise HTTPException(status_code=409, detail="Document is still being embedded. Please wait.")
    if not body.query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    # The engine takes plain dicts rather than ChatMessage models.
    past_turns = [dict(role=msg.role, content=msg.content) for msg in body.history]

    # top_k is set on the shared engine rather than passed per call.
    rag_engine.top_k = body.top_k
    # NOTE(review): answer() is called synchronously inside this async
    # handler, so it runs on the event loop; if it is slow, every other
    # request waits until it returns.
    return rag_engine.answer(body.query, history=past_turns, answer_mode=body.answer_mode)
async def query_stream(body: QueryRequest):
    """Stream the answer to ``body.query`` as Server-Sent Events.

    Each event yielded by ``rag_engine.stream_answer`` is JSON-encoded and sent
    as one ``data: ...`` frame. Validation happens before streaming starts and
    raises HTTPException: 400 if nothing has been ingested or the query is
    blank, 409 while the document is still being embedded.
    """
    if not rag_engine.chunks:
        raise HTTPException(status_code=400, detail="No document has been ingested yet. Upload a file first.")
    if not rag_engine.is_ready:
        raise HTTPException(status_code=409, detail="Document is still being embedded. Please wait.")
    if not body.query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    past_turns = [dict(role=msg.role, content=msg.content) for msg in body.history]
    rag_engine.top_k = body.top_k

    def sse_frames():
        # A generator body only runs once StreamingResponse starts iterating
        # it, so stream_answer() is invoked then, not when this handler
        # returns. NOTE(review): if stream_answer() reads rag_engine.top_k, a
        # concurrent request may have reassigned it in between.
        answer_events = rag_engine.stream_answer(
            body.query, history=past_turns, answer_mode=body.answer_mode
        )
        for payload in answer_events:
            yield "data: " + json.dumps(payload) + "\n\n"

    return StreamingResponse(sse_frames(), media_type="text/event-stream")
async def chunk_context(index: int, window: int = 1):
    """Return the chunk at ``index`` plus up to ``window`` neighbours on each side.

    Args:
        index: Position of the target chunk in ``rag_engine.chunks``.
        window: Number of neighbours to include on each side; must be >= 0.
            Neighbours that would fall outside the document are omitted.

    Returns:
        ``{"index": index, "chunks": [{"index", "text", "is_target"}, ...]}``,
        with chunks in document order.

    Raises:
        HTTPException: 400 if nothing is ingested or ``window`` is negative;
            404 if ``index`` is out of range.
    """
    if not rag_engine.chunks:
        raise HTTPException(status_code=400, detail="No document ingested.")
    # Without this check a negative window makes start > end, so the response
    # would silently omit even the target chunk.
    if window < 0:
        raise HTTPException(status_code=400, detail="window must be non-negative.")

    chunks = rag_engine.chunks
    total = len(chunks)
    if not 0 <= index < total:
        raise HTTPException(status_code=404, detail=f"Chunk index {index} out of range.")

    start = max(0, index - window)
    end = min(total, index + window + 1)
    return {
        "index": index,
        "chunks": [
            {"index": i, "text": chunks[i], "is_target": i == index}
            for i in range(start, end)
        ],
    }
async def list_chunks(offset: int = 0, limit: int = 50, search: str = ""):
    """Return a paginated, optionally filtered list of all ingested chunks.

    Args:
        offset: Number of matching chunks to skip; must be >= 0.
        limit: Maximum number of chunks to return; must be >= 0.
        search: Case-insensitive substring filter; empty means no filtering.

    Returns:
        ``{"total": <number of matches>, "offset", "limit",
        "chunks": [{"index", "text"}, ...]}``, where each ``index`` is the
        chunk's position in the full document.

    Raises:
        HTTPException: 400 if nothing is ingested or ``offset``/``limit`` is
            negative.
    """
    if not rag_engine.chunks:
        raise HTTPException(status_code=400, detail="No document ingested.")
    # Negative values would hit Python's negative-slice semantics and return a
    # page counted from the end (or nothing) instead of an error.
    if offset < 0 or limit < 0:
        raise HTTPException(status_code=400, detail="offset and limit must be non-negative.")

    source = rag_engine.chunks
    if search:
        # casefold() gives proper caseless matching for non-ASCII text,
        # e.g. "strasse" matches "Straße".
        needle = search.casefold()
        matches = [i for i, chunk in enumerate(source) if needle in chunk.casefold()]
    else:
        # A range supports len() and slicing without building an index list.
        matches = range(len(source))

    page = matches[offset: offset + limit]
    return {
        "total": len(matches),
        "offset": offset,
        "limit": limit,
        "chunks": [{"index": i, "text": source[i]} for i in page],
    }