import os
import uuid
import logging
import base64
import hashlib
from io import BytesIO
from datetime import datetime
from typing import List

import chromadb
from chromadb.utils import embedding_functions
from fastapi import FastAPI, UploadFile, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from openai import AsyncOpenAI
from pdf2image import convert_from_bytes
from PIL import Image
from pydantic import BaseModel
from sqlalchemy import Column, String, Text, DateTime, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker, Session
# HF SPACES CONFIG
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("dwani_backend")

# Create data dirs (note: /tmp is ephemeral across Space restarts)
os.makedirs("/tmp/chroma_db", exist_ok=True)
os.makedirs("/tmp/files", exist_ok=True)

# REQUIRED SECRET
DWANI_API_BASE_URL = os.getenv("DWANI_API_BASE_URL")
if not DWANI_API_BASE_URL:
    raise RuntimeError("🚨 Set DWANI_API_BASE_URL in Space Secrets!")
app = FastAPI(title="Dwani.ai RAG Backend v2.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# DATABASE
DATABASE_URL = "sqlite:////tmp/files.db"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()


class FileStatus:
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class FileRecord(Base):
    __tablename__ = "files"

    id = Column(String, primary_key=True)
    filename = Column(String, index=True)
    status = Column(String, default=FileStatus.PENDING)
    extracted_text = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)


Base.metadata.create_all(bind=engine)


def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
# CHROMA VECTOR DB
chroma_client = chromadb.PersistentClient(path="/tmp/chroma_db")
collection = chroma_client.get_or_create_collection(name="documents")
# Embeddings are computed explicitly with this function and passed to Chroma,
# rather than being attached to the collection itself.
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-small-en-v1.5"
)
# API SCHEMAS
class FileUploadResp(BaseModel):
    file_id: str
    filename: str
    status: str


class FileInfo(BaseModel):
    file_id: str
    filename: str
    status: str


class ChatRequest(BaseModel):
    file_ids: List[str]
    messages: List[dict]
# UTILITY FUNCTIONS
def encode_image(img: Image.Image) -> str:
    """Encode a PIL image as base64 JPEG for the vision API."""
    buf = BytesIO()
    img.save(buf, format="JPEG", quality=80)
    return base64.b64encode(buf.getvalue()).decode()


async def extract_pdf_text(pdf_bytes: bytes) -> List[str]:
    """OCR PDF pages using the vision model, one request per page."""
    # Placeholder API key for a self-hosted OpenAI-compatible endpoint
    client = AsyncOpenAI(api_key="http", base_url=DWANI_API_BASE_URL)
    images = convert_from_bytes(pdf_bytes, fmt="png", dpi=200)
    page_texts = []
    for img in images:
        img_b64 = encode_image(img)
        messages = [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
                {"type": "text", "text": "Extract all text from this page accurately."},
            ],
        }]
        resp = await client.chat.completions.create(
            model="gemma3",
            messages=messages,
            temperature=0.1,
            max_tokens=1500,
        )
        page_texts.append(resp.choices[0].message.content.strip())
    return page_texts

def create_chunks(page_texts: List[str], file_id: str, filename: str) -> List[dict]:
    """Create searchable chunks with metadata"""
    chunks = []
    for page_num, text in enumerate(page_texts, 1):
        # Split into 500-char chunks; skip fragments of 50 chars or fewer
        for i in range(0, len(text), 500):
            chunk = text[i:i + 500]
            if len(chunk.strip()) > 50:
                chunks.append({
                    "text": chunk.strip(),
                    "metadata": {
                        "file_id": file_id,
                        "filename": filename,
                        "page": page_num,
                    },
                })
    return chunks
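
# Worked example of the chunking above: a 1,200-character page yields slices
# [0:500), [500:1000), and [1000:1200); the final 200-character slice is kept
# because it clears the 50-character minimum after stripping whitespace.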

async def process_document(file_id: str, pdf_bytes: bytes, filename: str):
    """Background document processing pipeline"""
    # Open a dedicated session: the request-scoped session from Depends(get_db)
    # is torn down before background tasks run on current FastAPI versions.
    db = SessionLocal()
    try:
        record = db.query(FileRecord).filter(FileRecord.id == file_id).first()
        if not record:
            return
        record.status = FileStatus.PROCESSING
        db.commit()
        try:
            # 1. Extract text from PDF
            page_texts = await extract_pdf_text(pdf_bytes)
            full_text = "\n\n--- PAGE BREAK ---\n\n".join(page_texts)
            # 2. Save extracted text
            record.extracted_text = full_text
            record.status = FileStatus.COMPLETED
            db.commit()
            # 3. Create embeddings
            chunks = create_chunks(page_texts, file_id, filename)
            if chunks:
                docs = [c["text"] for c in chunks]
                metas = [c["metadata"] for c in chunks]
                ids = [f"{file_id}_{hashlib.md5(doc.encode()).hexdigest()}" for doc in docs]
                # Clear old embeddings for this file before re-adding
                collection.delete(where={"file_id": file_id})
                collection.add(
                    embeddings=embedding_fn(docs),
                    documents=docs,
                    metadatas=metas,
                    ids=ids,
                )
                logger.info(f"✅ Embedded {len(docs)} chunks for {filename}")
        except Exception as e:
            record.status = FileStatus.FAILED
            db.commit()
            logger.error(f"❌ Processing failed {filename}: {e}")
    finally:
        db.close()

# API ENDPOINTS - MATCHES YOUR GRADIO FRONTEND
# NOTE: the route decorators below were lost in extraction; paths are
# reconstructed from the endpoint list returned by root() and are assumptions.
@app.post("/files/upload", response_model=FileUploadResp)
async def upload_file(
    file: UploadFile,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
):
    if not file.filename.lower().endswith(".pdf"):
        raise HTTPException(400, detail="Only PDF files supported")
    content = await file.read()
    file_id = str(uuid.uuid4())
    # Create record
    record = FileRecord(
        id=file_id,
        filename=file.filename,
    )
    db.add(record)
    db.commit()
    # Start background processing (the task opens its own DB session)
    background_tasks.add_task(
        process_document, file_id, content, file.filename
    )
    return FileUploadResp(
        file_id=file_id,
        filename=file.filename,
        status="pending",
    )
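
# Example upload call (host and port are assumptions; 7860 is the Spaces default):
#   curl -X POST -F "file=@report.pdf" http://localhost:7860/files/upload
#   -> {"file_id": "...", "filename": "report.pdf", "status": "pending"}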

@app.get("/files/{file_id}", response_model=FileInfo)  # path reconstructed
def get_file_status(file_id: str, db: Session = Depends(get_db)):
    record = db.query(FileRecord).filter(FileRecord.id == file_id).first()
    if not record:
        raise HTTPException(404, "File not found")
    return FileInfo(
        file_id=record.id,
        filename=record.filename,
        status=record.status,
    )

@app.get("/files/")  # path reconstructed
def list_files(limit: int = 50, db: Session = Depends(get_db)):
    files = db.query(FileRecord).order_by(FileRecord.created_at.desc()).limit(limit).all()
    return [
        {
            "file_id": f.id,
            "filename": f.filename,
            "status": f.status,
            "created_at": f.created_at.isoformat(),
        }
        for f in files
    ]

@app.get("/", include_in_schema=False)  # assumed: redirect the bare root to the docs UI
async def home():
    return RedirectResponse(url="/docs")

@app.post("/chat-with-document")  # path confirmed by root()'s endpoint list
async def chat_with_documents(request: ChatRequest, db: Session = Depends(get_db)):
    # Validate that the files exist and are fully processed
    if not request.file_ids:
        raise HTTPException(400, "file_ids required")
    records = db.query(FileRecord).filter(FileRecord.id.in_(request.file_ids)).all()
    if len(records) != len(request.file_ids):
        raise HTTPException(404, "Some files not found")
    for record in records:
        if record.status != FileStatus.COMPLETED:
            raise HTTPException(400, f"File {record.filename} still processing")

    # Get the latest user question
    user_messages = [m for m in request.messages if m.get("role") == "user"]
    if not user_messages:
        raise HTTPException(400, "No user question found")
    question = user_messages[-1]["content"]

    # Vector search
    try:
        results = collection.query(
            query_embeddings=embedding_fn([question]),
            n_results=6,
            where={"file_id": {"$in": request.file_ids}},
            include=["documents", "metadatas", "distances"],
        )
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return {"answer": "Processing not complete yet", "sources": []}

    if not results["documents"] or not results["documents"][0]:
        return {"answer": "No relevant information found", "sources": []}

    # Build context and sources
    docs = results["documents"][0]
    metas = results["metadatas"][0]
    distances = results["distances"][0]
    context_parts = []
    sources = []
    for doc, meta, dist in zip(docs, metas, distances):
        context_parts.append(doc)
        sources.append({
            "filename": meta.get("filename", "Document"),
            "page": meta.get("page", 1),
            "excerpt": doc[:150] + "..." if len(doc) > 150 else doc,
            # Treats distance as (1 - similarity); the exact meaning depends on
            # the collection's distance metric (Chroma defaults to L2).
            "relevance_score": round(1 - dist, 3),
        })
    context = "\n\n".join(context_parts)

    # Generate answer
    client = AsyncOpenAI(api_key="http", base_url=DWANI_API_BASE_URL)
    system_prompt = f"""You are a helpful assistant. Use ONLY the following context to answer.

Context from documents:
{context}

Answer concisely and cite sources when possible."""
    messages = [
        {"role": "system", "content": system_prompt},
        *request.messages[-5:],  # Last 5 messages for conversational context
    ]
    response = await client.chat.completions.create(
        model="gemma3",
        messages=messages,
        temperature=0.3,
        max_tokens=800,
    )
    return {
        "answer": response.choices[0].message.content.strip(),
        "sources": sources[:4],
    }
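
# Example chat request (assuming port 7860 and a completed upload):
#   curl -X POST http://localhost:7860/chat-with-document \
#     -H "Content-Type: application/json" \
#     -d '{"file_ids": ["<file_id>"], "messages": [{"role": "user", "content": "Summarize the document"}]}'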

@app.get("/status")  # assumed path: "/" is taken by home() above
async def root():
    return {
        "status": "Dwani.ai RAG Backend ✅",
        "endpoints": ["/files/upload", "/files/", "/files/{id}", "/chat-with-document"],
        "docs": "/docs",
    }

@app.get("/health")  # conventional health-check path; decorator reconstructed
async def health():
    return {"status": "healthy"}
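
# Minimal entrypoint sketch: assumes the Space runs this file directly and
# exposes port 7860 (the Spaces default). If the Space starts uvicorn from a
# Dockerfile CMD instead, this block is simply unused.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)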