Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI endpoints for the RAG system with streaming support. | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from typing import Optional, List, AsyncGenerator | |
| import json | |
| import asyncio | |
| from src.app_hf import answer_question, add_urls, build_rag_chain, load_documents_from_crawler_cache | |
| from src.crawler import crawl_and_persist, async_crawl_and_persist | |
# ASGI application object served by uvicorn (see the __main__ guard at the bottom).
# NOTE(review): no route decorators (e.g. ``@app.get(...)``/``@app.post(...)``) are
# visible on the endpoint functions in this file; confirm they are registered on ``app``.
app = FastAPI(
    version="1.0.0",
    title="Documentation RAG API",
    description="Retrieval-Augmented Generation API for technical documentation",
)
class QueryRequest(BaseModel):
    # Request body accepted by the query and streaming-query endpoints.
    question: str  # the question to answer
    # Optional extra URLs; when non-empty they are passed to add_urls before answering,
    # and also forwarded to answer_question / build_rag_chain.
    urls: Optional[List[str]] = None
    # Forwarded to answer_question / build_rag_chain as the document directory
    # (presumably a local folder of docs to index — confirm in src.app_hf).
    doc_dir: Optional[str] = "./my_docs"
class CrawlRequest(BaseModel):
    # Request body accepted by the crawl-prepare and crawl-index endpoints.
    base_url: str  # site to crawl (prepare) / echoed back in the response (index)
    # NOTE(review): max_depth is accepted but prepare_crawl does not forward it
    # to the crawler — confirm whether it should.
    max_depth: Optional[int] = 3
    max_pages: Optional[int] = 100  # forwarded to async_crawl_and_persist as max_pages
    # index_from_crawl reads request.doc_dir; without this field every index call
    # failed with AttributeError (HTTP 500). Default mirrors QueryRequest.doc_dir.
    doc_dir: Optional[str] = "./my_docs"
class QueryResponse(BaseModel):
    # Response body returned by the non-streaming query endpoint.
    question: str  # echo of the submitted question
    answer: str  # answer text returned by answer_question
    # The request's urls echoed back unchanged (None when the request had none).
    source_urls: Optional[List[str]] = None
async def health_check():
    """Report that the service process is up.

    Always returns the same static payload; no downstream component is probed.
    """
    payload = dict(status="ok", service="RAG API")
    return payload
async def query(request: QueryRequest) -> QueryResponse:
    """
    Query the RAG system.
    Returns a complete answer in one response.

    If ``request.urls`` is non-empty the URLs are first passed to ``add_urls``;
    the question is then answered by ``answer_question`` with
    ``request.doc_dir`` and ``request.urls``.

    Both helpers are synchronous, so they are run in the event loop's default
    thread-pool executor rather than inline: calling them directly inside this
    coroutine would block the event loop, stalling every other request
    (including the health check) until the answer is ready.

    Args:
        request: question, optional extra URLs and document directory.

    Returns:
        QueryResponse echoing the question and URLs alongside the answer.

    Raises:
        HTTPException: status 500 carrying the error text if any step fails.
    """
    loop = asyncio.get_running_loop()
    try:
        if request.urls:
            await loop.run_in_executor(None, add_urls, request.urls)
        # NOTE(review): requests can now overlap instead of being serialized by a
        # blocked event loop; assumes add_urls / answer_question tolerate
        # concurrent calls — confirm in src.app_hf.
        answer = await loop.run_in_executor(
            None, answer_question, request.question, request.doc_dir, request.urls
        )
        return QueryResponse(
            question=request.question,
            answer=answer,
            source_urls=request.urls,
        )
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def query_stream(request: QueryRequest):
    """
    Query the RAG system with a streamed NDJSON response.

    The body is ``application/x-ndjson``: one JSON object per line, each with a
    ``type`` field:

    * ``start``: ``{"type": "start", "question": ...}``, sent once the chain is built
    * ``token``: ``{"type": "token", "content": ...}``, one per slice of up to
      10 characters of the answer
    * ``end``: ``{"type": "end", "answer": <full answer text>}``
    * ``error``: ``{"type": "error", "error": ...}``, sent if any step fails,
      in place of the events that had not been sent yet

    Note: the chain is called once with ``invoke`` and the finished answer is
    then re-sent in fixed-size slices, so this is not token-level streaming
    from the LLM.
    """

    def _line(event: dict) -> bytes:
        # Encode one NDJSON record (JSON object + newline).
        return json.dumps(event).encode() + b"\n"

    async def generate():
        loop = asyncio.get_running_loop()
        try:
            # add_urls / build_rag_chain / invoke are synchronous; run them in the
            # default executor so the event loop keeps serving other requests.
            # NOTE(review): assumes these helpers tolerate concurrent calls.
            if request.urls:
                await loop.run_in_executor(None, add_urls, request.urls)
            rag_chain = await loop.run_in_executor(
                None, build_rag_chain, request.doc_dir, request.urls
            )
            yield _line({"type": "start", "question": request.question})

            # Single blocking call that produces the complete answer.
            response = await loop.run_in_executor(None, rag_chain.invoke, request.question)
            answer_text = response.content if hasattr(response, "content") else str(response)

            # Re-emit the finished answer in small slices; the sleep only paces
            # the output, it does not wait on the model.
            chunk_size = 10
            for i in range(0, len(answer_text), chunk_size):
                yield _line({"type": "token", "content": answer_text[i:i + chunk_size]})
                await asyncio.sleep(0.01)

            yield _line({"type": "end", "answer": answer_text})
        except Exception as e:
            yield _line({"type": "error", "error": str(e)})

    return StreamingResponse(generate(), media_type="application/x-ndjson")
async def prepare_crawl(request: CrawlRequest):
    """
    Crawl a website and persist the pages for later indexing.

    Awaits ``async_crawl_and_persist`` for ``request.base_url``, writing to
    ``./crawler_docs.json`` and limited to ``request.max_pages``. The summary is
    returned only after the crawl has finished; no background job is created.

    Args:
        request: crawl parameters (``base_url``, ``max_pages``).

    Returns:
        dict with ``status``, ``documents_crawled``, ``failed_urls``,
        ``base_url`` and a human-readable ``message``.

    Raises:
        HTTPException: status 500 carrying the error text if the crawl fails.
    """
    try:
        # NOTE(review): request.max_depth is not forwarded here; confirm whether
        # async_crawl_and_persist accepts a depth limit.
        documents = await async_crawl_and_persist(
            request.base_url,
            output_path="./crawler_docs.json",
            max_pages=request.max_pages,
        )
        count = len(documents)
        return {
            "status": "success",
            "documents_crawled": count,
            # NOTE(review): hard-coded; no failure count is obtained from the crawler.
            "failed_urls": 0,
            "base_url": request.base_url,
            "message": f"Successfully crawled and persisted {count} pages from {request.base_url}",
        }
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def index_from_crawl(request: CrawlRequest):
    """
    Index previously crawled pages from the crawler cache.

    This function does not crawl. It loads cached documents via
    ``load_documents_from_crawler_cache`` and then calls ``build_rag_chain`` for
    the document directory. Run the crawl-prepare endpoint first to fill the cache.

    Args:
        request: ``base_url`` is echoed back in the response. It is not used to
            select which cached documents are indexed. ``doc_dir`` is optional.

    Returns:
        dict with ``status``, ``documents_indexed``, ``base_url`` and ``message``.

    Raises:
        HTTPException: 404 if the cache is empty, 500 on any other failure.
    """
    # CrawlRequest may not declare ``doc_dir``. A plain attribute access would then
    # raise AttributeError and turn every call into a 500, so fall back to the
    # same default that QueryRequest uses.
    doc_dir = getattr(request, "doc_dir", "./my_docs")
    loop = asyncio.get_running_loop()
    try:
        # Both helpers are synchronous; keep them off the event loop.
        cached_docs = await loop.run_in_executor(None, load_documents_from_crawler_cache)
        if not cached_docs:
            raise HTTPException(
                status_code=404,
                detail="No cached crawler documents found. Run /crawl/prepare first."
            )
        # The return value is discarded. This is presumably called for its side
        # effect of building/refreshing the index (TODO confirm in src.app_hf).
        await loop.run_in_executor(None, lambda: build_rag_chain(doc_dir, urls=[]))
        return {
            "status": "success",
            "documents_indexed": len(cached_docs),
            "base_url": request.base_url,
            "message": f"Successfully indexed {len(cached_docs)} cached crawler pages",
        }
    except HTTPException:
        # Let the deliberate 404 through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve ``app`` with uvicorn on all interfaces, port 8000.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)