# RAG_backend/app.py
import os
import uuid
import logging
import base64
from io import BytesIO
from datetime import datetime
from typing import List, Optional
import hashlib
from fastapi.responses import RedirectResponse
import chromadb
from chromadb.utils import embedding_functions
from fastapi import FastAPI, UploadFile, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from openai import AsyncOpenAI
from pdf2image import convert_from_bytes
from PIL import Image
from pydantic import BaseModel
from sqlalchemy import Column, String, Text, DateTime, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
# HF SPACES CONFIG
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("dwani_backend")
# Create persistent dirs
os.makedirs("/tmp/chroma_db", exist_ok=True)
os.makedirs("/tmp/files", exist_ok=True)
# REQUIRED SECRET
DWANI_API_BASE_URL = os.getenv("DWANI_API_BASE_URL")
if not DWANI_API_BASE_URL:
raise RuntimeError("🚨 Set DWANI_API_BASE_URL in Space Secrets!")
app = FastAPI(title="Dwani.ai RAG Backend v2.0")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# DATABASE
DATABASE_URL = "sqlite:////tmp/files.db"
# check_same_thread=False lets the SQLite connection be used from
# FastAPI's threadpool workers
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()
class FileStatus:
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
class FileRecord(Base):
__tablename__ = "files"
id = Column(String, primary_key=True)
filename = Column(String, index=True)
status = Column(String, default=FileStatus.PENDING)
extracted_text = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
Base.metadata.create_all(bind=engine)
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
# CHROMA VECTOR DB
chroma_client = chromadb.PersistentClient(path="/tmp/chroma_db")
# The collection is created without an embedding function, so embeddings
# are computed explicitly with embedding_fn on every add/query below.
collection = chroma_client.get_or_create_collection(name="documents")
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-small-en-v1.5"
)
# API SCHEMAS
class FileUploadResp(BaseModel):
file_id: str
filename: str
status: str
class FileInfo(BaseModel):
file_id: str
filename: str
status: str
class ChatRequest(BaseModel):
file_ids: List[str]
messages: List[dict]
# UTILITY FUNCTIONS
def encode_image(img: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string."""
    buf = BytesIO()
    img.convert("RGB").save(buf, format="JPEG", quality=80)  # JPEG has no alpha channel
    return base64.b64encode(buf.getvalue()).decode()
async def extract_pdf_text(pdf_bytes: bytes) -> List[str]:
    """OCR PDF pages using the vision model behind DWANI_API_BASE_URL."""
    # The API key is a placeholder; the self-hosted endpoint is assumed
    # not to validate it. pdf2image requires poppler to be installed.
    client = AsyncOpenAI(api_key="http", base_url=DWANI_API_BASE_URL)
    images = convert_from_bytes(pdf_bytes, dpi=200)
    page_texts = []
    for img in images:
        img_b64 = encode_image(img)  # pages are re-encoded as JPEG
        messages = [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
                {"type": "text", "text": "Extract all text from this page accurately."}
            ]
        }]
        resp = await client.chat.completions.create(
            model="gemma3",
            messages=messages,
            temperature=0.1,
            max_tokens=1500
        )
        page_texts.append((resp.choices[0].message.content or "").strip())
    return page_texts
def create_chunks(page_texts: List[str], file_id: str, filename: str) -> List[dict]:
"""Create searchable chunks with metadata"""
chunks = []
for page_num, text in enumerate(page_texts, 1):
        # Naive fixed-width split: 500-char windows, no overlap
for i in range(0, len(text), 500):
chunk = text[i:i+500]
if len(chunk.strip()) > 50:
chunks.append({
"text": chunk.strip(),
"metadata": {
"file_id": file_id,
"filename": filename,
"page": page_num
}
})
return chunks
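# For example, a 1200-character page yields the windows [0:500], [500:1000],
# and [1000:1200]; the trailing 200-character fragment is kept because it is
# longer than the 50-character minimum after stripping.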
async def process_document(file_id: str, pdf_bytes: bytes, filename: str):
    """Background document processing pipeline."""
    # Open a fresh session: the request-scoped session from Depends(get_db)
    # may already be closed by the time this background task runs.
    db = SessionLocal()
    try:
        record = db.query(FileRecord).filter(FileRecord.id == file_id).first()
        if not record:
            return
        record.status = FileStatus.PROCESSING
        db.commit()
        try:
            # 1. Extract text from the PDF
            page_texts = await extract_pdf_text(pdf_bytes)
            full_text = "\n\n--- PAGE BREAK ---\n\n".join(page_texts)
            # 2. Save the extracted text
            record.extracted_text = full_text
            # 3. Create embeddings
            chunks = create_chunks(page_texts, file_id, filename)
            if chunks:
                docs = [c["text"] for c in chunks]
                metas = [c["metadata"] for c in chunks]
                # Include the chunk index so identical chunks get unique IDs
                ids = [
                    f"{file_id}_{i}_{hashlib.md5(doc.encode()).hexdigest()}"
                    for i, doc in enumerate(docs)
                ]
                # Replace any old embeddings for this file
                collection.delete(where={"file_id": file_id})
                collection.add(
                    embeddings=embedding_fn(docs),
                    documents=docs,
                    metadatas=metas,
                    ids=ids
                )
                logger.info(f"✅ Embedded {len(docs)} chunks for {filename}")
            # Mark completed only after text and embeddings are stored
            record.status = FileStatus.COMPLETED
        except Exception as e:
            record.status = FileStatus.FAILED
            logger.error(f"❌ Processing failed {filename}: {e}")
        db.commit()
    finally:
        db.close()
# API ENDPOINTS (consumed by the Gradio frontend)
@app.post("/files/upload", response_model=FileUploadResp)
async def upload_file(
file: UploadFile,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db)
):
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(400, detail="Only PDF files are supported")
content = await file.read()
file_id = str(uuid.uuid4())
# Create record
record = FileRecord(
id=file_id,
filename=file.filename
)
db.add(record)
db.commit()
    # Start background processing (process_document opens its own DB
    # session; the request-scoped one is closed once the response is sent)
    background_tasks.add_task(
        process_document, file_id, content, file.filename
    )
return FileUploadResp(
file_id=file_id,
filename=file.filename,
status="pending"
)
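# Example upload (illustrative; replace <space-url> with your deployment):
#   curl -F "file=@report.pdf" https://<space-url>/files/upload
#   -> {"file_id": "...", "filename": "report.pdf", "status": "pending"}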
@app.get("/files/{file_id}", response_model=FileInfo)
def get_file_status(file_id: str, db: Session = Depends(get_db)):
record = db.query(FileRecord).filter(FileRecord.id == file_id).first()
if not record:
raise HTTPException(404, "File not found")
return FileInfo(
file_id=record.id,
filename=record.filename,
status=record.status
)
@app.get("/files/")
def list_files(limit: int = 50, db: Session = Depends(get_db)):
files = db.query(FileRecord).order_by(FileRecord.created_at.desc()).limit(limit).all()
return [
{
"file_id": f.id,
"filename": f.filename,
"status": f.status,
"created_at": f.created_at.isoformat()
}
for f in files
]
@app.get("/",
summary="Redirect to Docs",
description="Redirects to the Swagger UI documentation.",
tags=["Utility"])
async def home():
return RedirectResponse(url="/docs")
@app.post("/chat-with-document")
async def chat_with_documents(request: ChatRequest, db: Session = Depends(get_db)):
# Validate files exist and are processed
if not request.file_ids:
raise HTTPException(400, "file_ids required")
records = db.query(FileRecord).filter(FileRecord.id.in_(request.file_ids)).all()
if len(records) != len(request.file_ids):
raise HTTPException(404, "Some files not found")
for record in records:
if record.status != FileStatus.COMPLETED:
raise HTTPException(400, f"File {record.filename} still processing")
# Get latest user question
user_messages = [m for m in request.messages if m.get("role") == "user"]
if not user_messages:
raise HTTPException(400, "No user question found")
question = user_messages[-1]["content"]
# Vector search
try:
results = collection.query(
query_embeddings=embedding_fn([question]),
n_results=6,
where={"file_id": {"$in": request.file_ids}},
include=["documents", "metadatas", "distances"]
)
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return {"answer": "Search failed, please try again.", "sources": []}
if not results["documents"] or not results["documents"][0]:
return {"answer": "No relevant information found", "sources": []}
# Build context and sources
docs = results["documents"][0]
metas = results["metadatas"][0]
distances = results["distances"][0]
context_parts = []
sources = []
for i, (doc, meta, dist) in enumerate(zip(docs, metas, distances)):
context_parts.append(doc)
sources.append({
"filename": meta.get("filename", "Document"),
"page": meta.get("page", 1),
"excerpt": doc[:150] + "..." if len(doc) > 150 else doc,
"relevance_score": round(1 - dist, 3)
})
context = "\n\n".join(context_parts)
# Generate answer
client = AsyncOpenAI(api_key="http", base_url=DWANI_API_BASE_URL)
system_prompt = f"""You are a helpful assistant. Use ONLY the following context to answer.
Context from documents:
{context}
Answer concisely and cite sources when possible."""
messages = [
{"role": "system", "content": system_prompt},
*request.messages[-5:] # Last 5 messages for context
]
response = await client.chat.completions.create(
model="gemma3",
messages=messages,
temperature=0.3,
max_tokens=800
)
return {
"answer": response.choices[0].message.content.strip(),
"sources": sources[:4]
}
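# Example request (illustrative; the file_id is hypothetical):
#   POST /chat-with-document
#   {
#     "file_ids": ["3f2b0c1e-..."],
#     "messages": [{"role": "user", "content": "What does the report conclude?"}]
#   }
# The response carries "answer" plus up to four "sources" entries with
# filename, page, excerpt, and relevance_score.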
@app.get("/")
async def root():
return {
"status": "Dwani.ai RAG Backend ✅",
"endpoints": ["/files/upload", "/files/", "/files/{id}", "/chat-with-document"],
"docs": "/docs"
}
@app.get("/health")
async def health():
return {"status": "healthy"}