import os
import uuid
import logging
import base64
import hashlib
from io import BytesIO
from datetime import datetime
from typing import List, Optional

import chromadb
from chromadb.utils import embedding_functions
from fastapi import FastAPI, UploadFile, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from openai import AsyncOpenAI
from pdf2image import convert_from_bytes
from PIL import Image
from pydantic import BaseModel
from sqlalchemy import Column, String, Text, DateTime, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session

# ---------------------------------------------------------------- HF Spaces config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("dwani_backend")

# /tmp is the only guaranteed-writable location on HF Spaces.
os.makedirs("/tmp/chroma_db", exist_ok=True)
os.makedirs("/tmp/files", exist_ok=True)

# REQUIRED SECRET
DWANI_API_BASE_URL = os.getenv("DWANI_API_BASE_URL")
if not DWANI_API_BASE_URL:
    raise RuntimeError("🚨 Set DWANI_API_BASE_URL in Space Secrets!")

app = FastAPI(title="Dwani.ai RAG Backend v2.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---------------------------------------------------------------- database
DATABASE_URL = "sqlite:////tmp/files.db"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()


class FileStatus:
    """Lifecycle states of an uploaded document."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class FileRecord(Base):
    """One row per uploaded PDF: OCR'd text plus processing status."""

    __tablename__ = "files"
    id = Column(String, primary_key=True)
    filename = Column(String, index=True)
    status = Column(String, default=FileStatus.PENDING)
    extracted_text = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)


Base.metadata.create_all(bind=engine)


def get_db():
    """FastAPI dependency: yield a request-scoped session, always close it.

    NOTE: the session is closed as soon as the response is sent, so it must
    never be handed to a BackgroundTasks task (see process_document).
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# ---------------------------------------------------------------- vector store
chroma_client = chromadb.PersistentClient(path="/tmp/chroma_db")
collection = chroma_client.get_or_create_collection(name="documents")
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-small-en-v1.5"
)


# ---------------------------------------------------------------- API schemas
class FileUploadResp(BaseModel):
    file_id: str
    filename: str
    status: str


class FileInfo(BaseModel):
    file_id: str
    filename: str
    status: str


class ChatRequest(BaseModel):
    file_ids: List[str]
    messages: List[dict]


# ---------------------------------------------------------------- utilities
def encode_image(img: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string.

    BUGFIX: convert to RGB first — JPEG has no alpha channel, so saving an
    RGBA/P-mode page would raise.
    """
    buf = BytesIO()
    img.convert("RGB").save(buf, format="JPEG", quality=80)
    return base64.b64encode(buf.getvalue()).decode()


async def extract_pdf_text(pdf_bytes: bytes) -> List[str]:
    """OCR each PDF page through the vision model; return one string per page."""
    client = AsyncOpenAI(api_key="http", base_url=DWANI_API_BASE_URL)
    images = convert_from_bytes(pdf_bytes, fmt="png", dpi=200)
    page_texts: List[str] = []
    for img in images:
        img_b64 = encode_image(img)
        messages = [{
            "role": "user",
            "content": [
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
                {"type": "text",
                 "text": "Extract all text from this page accurately."},
            ],
        }]
        resp = await client.chat.completions.create(
            model="gemma3",
            messages=messages,
            temperature=0.1,
            max_tokens=1500,
        )
        page_texts.append(resp.choices[0].message.content.strip())
    return page_texts


def create_chunks(page_texts: List[str], file_id: str, filename: str) -> List[dict]:
    """Split page texts into ~500-char chunks tagged with file/page metadata.

    Chunks whose stripped text is 50 chars or fewer are dropped as noise.
    """
    chunks: List[dict] = []
    for page_num, text in enumerate(page_texts, 1):
        for start in range(0, len(text), 500):
            piece = text[start:start + 500].strip()
            if len(piece) > 50:
                chunks.append({
                    "text": piece,
                    "metadata": {
                        "file_id": file_id,
                        "filename": filename,
                        "page": page_num,
                    },
                })
    return chunks


async def process_document(
    file_id: str,
    pdf_bytes: bytes,
    filename: str,
    db: Optional[Session] = None,
):
    """Background pipeline: OCR → persist text → embed chunks.

    BUGFIX: the request-scoped session produced by get_db() is closed as soon
    as the HTTP response is sent, i.e. before this background task runs.  The
    task therefore owns its own session; the `db` parameter is retained only
    for backward compatibility and is ignored.
    """
    session = SessionLocal()
    try:
        record = session.query(FileRecord).filter(FileRecord.id == file_id).first()
        if not record:
            return
        record.status = FileStatus.PROCESSING
        session.commit()
        try:
            # 1. Extract text from PDF.
            page_texts = await extract_pdf_text(pdf_bytes)
            record.extracted_text = "\n\n--- PAGE BREAK ---\n\n".join(page_texts)
            session.commit()

            # 2. Create embeddings.
            chunks = create_chunks(page_texts, file_id, filename)
            if chunks:
                docs = [c["text"] for c in chunks]
                metas = [c["metadata"] for c in chunks]
                # BUGFIX: include the chunk index so identical chunk text can
                # never produce duplicate Chroma ids within one file.
                ids = [
                    f"{file_id}_{idx}_{hashlib.md5(doc.encode()).hexdigest()}"
                    for idx, doc in enumerate(docs)
                ]
                # Clear stale embeddings before re-adding.
                collection.delete(where={"file_id": file_id})
                collection.add(
                    embeddings=embedding_fn(docs),
                    documents=docs,
                    metadatas=metas,
                    ids=ids,
                )
                logger.info("✅ Embedded %d chunks for %s", len(docs), filename)

            # BUGFIX: mark COMPLETED only after the vectors exist — the chat
            # endpoint gates on COMPLETED and then queries the vector store,
            # so flipping the flag earlier created a window of silent misses.
            record.status = FileStatus.COMPLETED
        except Exception as exc:
            record.status = FileStatus.FAILED
            logger.error("❌ Processing failed for %s: %s", filename, exc)
        session.commit()
    finally:
        session.close()


# ---------------------------------------------------------------- endpoints
@app.post("/files/upload", response_model=FileUploadResp)
async def upload_file(
    file: UploadFile,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
):
    """Accept a PDF upload and queue OCR + embedding in the background."""
    # BUGFIX: file.filename can be None; guard before .lower().
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(400, detail="Only PDF files supported")

    content = await file.read()
    file_id = str(uuid.uuid4())

    db.add(FileRecord(id=file_id, filename=file.filename))
    db.commit()

    # BUGFIX: do NOT pass the request-scoped `db` to the background task —
    # it is closed once the response goes out.  process_document opens its own.
    background_tasks.add_task(process_document, file_id, content, file.filename)

    return FileUploadResp(
        file_id=file_id,
        filename=file.filename,
        status="pending",
    )


@app.get("/files/{file_id}", response_model=FileInfo)
def get_file_status(file_id: str, db: Session = Depends(get_db)):
    """Return the processing status of a single uploaded file."""
    record = db.query(FileRecord).filter(FileRecord.id == file_id).first()
    if not record:
        raise HTTPException(404, "File not found")
    return FileInfo(file_id=record.id, filename=record.filename, status=record.status)


@app.get("/files/")
def list_files(limit: int = 50, db: Session = Depends(get_db)):
    """List the most recently uploaded files (newest first)."""
    files = (
        db.query(FileRecord)
        .order_by(FileRecord.created_at.desc())
        .limit(limit)
        .all()
    )
    return [
        {
            "file_id": f.id,
            "filename": f.filename,
            "status": f.status,
            "created_at": f.created_at.isoformat(),
        }
        for f in files
    ]


# BUGFIX: "/" was registered twice (this redirect plus a later info handler);
# Starlette matches routes in registration order, so the second handler was
# dead code and has been removed.
@app.get(
    "/",
    summary="Redirect to Docs",
    description="Redirects to the Swagger UI documentation.",
    tags=["Utility"],
)
async def home():
    return RedirectResponse(url="/docs")


@app.post("/chat-with-document")
async def chat_with_documents(request: ChatRequest, db: Session = Depends(get_db)):
    """Answer a user question grounded in the selected documents.

    Validates that every requested file exists and is fully processed, runs a
    vector search over the files' chunks, then asks the LLM to answer using
    only the retrieved context.  Returns the answer plus up to 4 sources.
    """
    # Validate files exist and are processed.
    if not request.file_ids:
        raise HTTPException(400, "file_ids required")

    records = db.query(FileRecord).filter(FileRecord.id.in_(request.file_ids)).all()
    if len(records) != len(request.file_ids):
        raise HTTPException(404, "Some files not found")
    for record in records:
        if record.status != FileStatus.COMPLETED:
            raise HTTPException(400, f"File {record.filename} still processing")

    # Latest user question drives the retrieval.
    user_messages = [m for m in request.messages if m.get("role") == "user"]
    if not user_messages:
        raise HTTPException(400, "No user question found")
    question = user_messages[-1]["content"]

    # Vector search restricted to the requested files.
    try:
        results = collection.query(
            query_embeddings=embedding_fn([question]),
            n_results=6,
            where={"file_id": {"$in": request.file_ids}},
            include=["documents", "metadatas", "distances"],
        )
    except Exception as exc:
        logger.error("Vector search failed: %s", exc)
        return {"answer": "Processing not complete yet", "sources": []}

    if not results["documents"] or not results["documents"][0]:
        return {"answer": "No relevant information found", "sources": []}

    # Build LLM context and the source list returned to the client.
    docs = results["documents"][0]
    metas = results["metadatas"][0]
    distances = results["distances"][0]

    context_parts = []
    sources = []
    for doc, meta, dist in zip(docs, metas, distances):
        context_parts.append(doc)
        sources.append({
            "filename": meta.get("filename", "Document"),
            "page": meta.get("page", 1),
            "excerpt": doc[:150] + "..." if len(doc) > 150 else doc,
            # Chroma returns a distance; 1 - distance approximates similarity.
            "relevance_score": round(1 - dist, 3),
        })

    context = "\n\n".join(context_parts)

    # Generate the grounded answer.
    client = AsyncOpenAI(api_key="http", base_url=DWANI_API_BASE_URL)
    system_prompt = f"""You are a helpful assistant. Use ONLY the following context to answer.

Context from documents:
{context}

Answer concisely and cite sources when possible."""

    messages = [
        {"role": "system", "content": system_prompt},
        *request.messages[-5:],  # Last 5 messages for conversational context
    ]

    response = await client.chat.completions.create(
        model="gemma3",
        messages=messages,
        temperature=0.3,
        max_tokens=800,
    )

    return {
        "answer": response.choices[0].message.content.strip(),
        "sources": sources[:4],
    }


@app.get("/health")
async def health():
    return {"status": "healthy"}