# Source: MohitGupta41 — FastAPI RAG backend (Docker), commit f7c12a3
# rag.py
import uuid
import time
from typing import List, Dict, Any, Tuple
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
# PDF extraction
import fitz # pymupdf
# LLM (Qwen)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# -----------------------------
# Globals (MVP)
# -----------------------------
# Sentence-embedding model shared by indexing (create_session) and querying
# (retrieve_top_k). MiniLM produces 384-dim vectors; loaded once at import time.
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")

# Instruction-tuned Qwen checkpoint used for answer generation.
QWEN_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    # fp16 on GPU for memory/speed; fall back to fp32 on CPU where fp16 is slow.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    # Let accelerate place layers across available devices automatically.
    device_map="auto",
)
# Text-generation pipeline wrapping the model/tokenizer pair above.
GENERATOR = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# In-memory session store (MVP: no persistence, no eviction — grows unbounded).
# session_id -> {chunks, index, created_at, docs}
SESSIONS: Dict[str, Dict[str, Any]] = {}
# -----------------------------
# Helpers
# -----------------------------
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Extract plain text from an in-memory PDF.

    Args:
        pdf_bytes: Raw bytes of a PDF file.

    Returns:
        The text of all pages joined with newlines, stripped of leading/
        trailing whitespace. Empty string for a PDF with no extractable text.

    Raises:
        Whatever ``fitz.open`` raises for malformed/non-PDF input.
    """
    pages: List[str] = []
    # Bug fix: the original never closed the Document, leaking the underlying
    # MuPDF resources. fitz.Document supports the context-manager protocol.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page in doc:
            pages.append(page.get_text("text"))
    return "\n".join(pages).strip()
def chunk_text(text: str, chunk_size_words: int = 350, overlap_words: int = 60) -> List[str]:
    """Split *text* into word-based chunks of ``chunk_size_words`` words,
    with consecutive chunks sharing ``overlap_words`` words.

    Returns an empty list for empty/whitespace-only input. The final chunk
    may be shorter than ``chunk_size_words``.
    """
    tokens = text.split()
    # Guard against a non-positive stride if overlap >= chunk size.
    stride = max(1, chunk_size_words - overlap_words)
    return [
        " ".join(tokens[start:start + chunk_size_words])
        for start in range(0, len(tokens), stride)
    ]
def build_faiss_index(vectors: np.ndarray) -> faiss.Index:
    """Build an inner-product FAISS index over L2-normalized vectors.

    Normalizing first makes inner product equivalent to cosine similarity.
    ``astype`` copies, so the caller's array is never mutated in place.
    """
    mat = vectors.astype("float32")
    idx = faiss.IndexFlatIP(mat.shape[1])
    faiss.normalize_L2(mat)
    idx.add(mat)
    return idx
def retrieve_top_k(
    query: str,
    chunks: List[str],
    index: faiss.Index,
    k: int = 3
) -> List[Tuple[int, float, str]]:
    """Return the top-*k* chunks most similar to *query*.

    Each result is ``(chunk_index, cosine_score, chunk_text)``, best first.
    Slots FAISS could not fill (id == -1, when k > index size) are skipped.
    """
    query_vec = EMBEDDER.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(query_vec)
    scores, ids = index.search(query_vec, k)

    hits: List[Tuple[int, float, str]] = []
    for score, chunk_id in zip(scores[0], ids[0]):
        if chunk_id == -1:
            continue
        pos = int(chunk_id)
        hits.append((pos, float(score), chunks[pos]))
    return hits
def _build_qwen_prompt(question: str, context: str) -> str:
    """Render the grounded-QA chat prompt via the tokenizer's chat template.

    The system message constrains the model to the retrieved context and
    fixes the exact refusal string used when the answer is absent.
    """
    system_msg = (
        "You are a medical QA assistant. "
        "Answer using ONLY the provided context. "
        "If the answer is not present in the context, say exactly: "
        "'Not found in the provided documents.'"
    )
    user_msg = f"Context:\n{context}\n\nQuestion:\n{question}"
    chat = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
def generate_answer(question: str, context: str) -> str:
    """Generate an answer to *question* grounded in *context* with Qwen.

    Uses low-temperature sampling; ``return_full_text=False`` keeps only the
    newly generated continuation, which is returned stripped of whitespace.
    """
    outputs = GENERATOR(
        _build_qwen_prompt(question, context),
        max_new_tokens=256,
        temperature=0.2,
        do_sample=True,
        return_full_text=False,
    )
    answer = outputs[0]["generated_text"]
    return answer.strip()
def create_session(chunks: List[str], docs: List[Dict[str, Any]]) -> str:
    """Embed *chunks*, index them, and register a new retrieval session.

    Args:
        chunks: Non-empty list of text chunks to embed and index.
        docs: list of {"doc_id": int, "filename": str, "num_chunks": int}

    Returns:
        A fresh UUID4 session id keying the entry stored in SESSIONS.

    Raises:
        ValueError: if *chunks* is empty (the original crashed later with an
            opaque shape error inside embedding/index construction).
    """
    if not chunks:
        raise ValueError("create_session requires at least one non-empty chunk")
    embeddings = EMBEDDER.encode(chunks, convert_to_numpy=True)
    index = build_faiss_index(embeddings)
    session_id = str(uuid.uuid4())
    SESSIONS[session_id] = {
        "chunks": chunks,
        "index": index,
        "created_at": time.time(),
        "docs": docs,
    }
    return session_id