# rag.py
import uuid
import time
from typing import List, Dict, Any, Tuple

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# PDF extraction
import fitz  # pymupdf

# LLM (Qwen)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# -----------------------------
# Globals (MVP)
# -----------------------------
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")

QWEN_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
GENERATOR = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# session_id -> {chunks, index, created_at, docs}
SESSIONS: Dict[str, Dict[str, Any]] = {}

# -----------------------------
# Helpers
# -----------------------------
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Extract plain text from an in-memory PDF, one page per line group."""
    pages = []
    # Context manager ensures the document handle is released.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page in doc:
            pages.append(page.get_text("text"))
    return "\n".join(pages).strip()


def chunk_text(text: str, chunk_size_words: int = 350, overlap_words: int = 60) -> List[str]:
    """Split text into word-based chunks with a sliding-window overlap."""
    words = text.split()
    chunks: List[str] = []
    # Advance by chunk size minus overlap so consecutive chunks share context.
    step = max(1, chunk_size_words - overlap_words)
    for i in range(0, len(words), step):
        chunk = words[i:i + chunk_size_words]
        if chunk:
            chunks.append(" ".join(chunk))
    return chunks


def build_faiss_index(vectors: np.ndarray) -> faiss.Index:
    vectors = vectors.astype("float32")
    dim = vectors.shape[1]
    # L2-normalized vectors + inner product == cosine similarity.
    index = faiss.IndexFlatIP(dim)
    faiss.normalize_L2(vectors)
    index.add(vectors)
    return index


def retrieve_top_k(
    query: str,
    chunks: List[str],
    index: faiss.Index,
    k: int = 3,
) -> List[Tuple[int, float, str]]:
    """Return the top-k chunks for a query as (chunk_id, score, text) tuples."""
    q = EMBEDDER.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q)
    scores, ids = index.search(q, k)
    results: List[Tuple[int, float, str]] = []
    for rank, idx in enumerate(ids[0]):
        # FAISS pads with -1 when fewer than k vectors are indexed.
        if idx == -1:
            continue
        results.append((int(idx), float(scores[0][rank]), chunks[int(idx)]))
    return results


def _build_qwen_prompt(question: str, context: str) -> str:
    messages = [
        {
            "role": "system",
            "content": (
                "You are a medical QA assistant. "
                "Answer using ONLY the provided context. "
                "If the answer is not present in the context, say exactly: "
                "'Not found in the provided documents.'"
            ),
        },
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def generate_answer(question: str, context: str) -> str:
    prompt = _build_qwen_prompt(question, context)
    out = GENERATOR(
        prompt,
        max_new_tokens=256,
        temperature=0.2,
        do_sample=True,
        return_full_text=False,
    )
    return out[0]["generated_text"].strip()


def create_session(chunks: List[str], docs: List[Dict[str, Any]]) -> str:
    """
    docs: list of {"doc_id": int, "filename": str, "num_chunks": int}
    """
    embeddings = EMBEDDER.encode(chunks, convert_to_numpy=True)
    index = build_faiss_index(embeddings)
    session_id = str(uuid.uuid4())
    SESSIONS[session_id] = {
        "chunks": chunks,
        "index": index,
        "created_at": time.time(),
        "docs": docs,
    }
    return session_id
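

# -----------------------------
# Example usage (minimal sketch)
# -----------------------------
# A hedged end-to-end demo of the pipeline above: extract -> chunk ->
# index -> retrieve -> generate. The file path "sample.pdf" and the
# question string are placeholders for illustration, not part of the
# module proper.
if __name__ == "__main__":
    with open("sample.pdf", "rb") as f:  # hypothetical input file
        pdf_bytes = f.read()

    text = extract_text_from_pdf(pdf_bytes)
    chunks = chunk_text(text)
    session_id = create_session(
        chunks,
        docs=[{"doc_id": 0, "filename": "sample.pdf", "num_chunks": len(chunks)}],
    )

    session = SESSIONS[session_id]
    question = "What adverse events are reported?"  # placeholder question
    hits = retrieve_top_k(question, session["chunks"], session["index"], k=3)
    # Concatenate retrieved chunks into the grounding context for the LLM.
    context = "\n\n".join(chunk for _, _, chunk in hits)
    print(generate_answer(question, context))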