Spaces:
Sleeping
Sleeping
| # rag.py | |
| import uuid | |
| import time | |
| from typing import List, Dict, Any, Tuple | |
| import numpy as np | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| # PDF extraction | |
| import fitz # pymupdf | |
| # LLM (Qwen) | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
# -----------------------------
# Globals (MVP)
# -----------------------------

# Sentence-embedding model, shared by indexing (create_session) and
# querying (retrieve_top_k) so both live in the same vector space.
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")

# Qwen instruct model used as the answer generator.
QWEN_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    # fp16 on GPU halves memory; fp32 on CPU where fp16 is slow/unsupported.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
GENERATOR = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# In-memory session store (MVP: not persistent, not size-bounded).
# session_id -> {chunks, index, created_at, docs}
SESSIONS: Dict[str, Dict[str, Any]] = {}
| # ----------------------------- | |
| # Helpers | |
| # ----------------------------- | |
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Extract plain text from a PDF supplied as raw bytes.

    Args:
        pdf_bytes: The complete PDF file content.

    Returns:
        All page texts joined with newlines, outer whitespace stripped.
    """
    # Context manager guarantees the document (and its native resources)
    # is closed even if extraction raises — the original leaked the handle.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        pages = [page.get_text("text") for page in doc]
    return "\n".join(pages).strip()
def chunk_text(text: str, chunk_size_words: int = 350, overlap_words: int = 60) -> List[str]:
    """Split *text* into word-based chunks.

    Each chunk holds up to ``chunk_size_words`` words, and consecutive
    chunks share ``overlap_words`` words of context.
    """
    words = text.split()
    # Start each chunk `stride` words after the previous one; max(1, ...)
    # guarantees forward progress even for degenerate overlap settings.
    stride = max(1, chunk_size_words - overlap_words)
    return [
        " ".join(words[start:start + chunk_size_words])
        for start in range(0, len(words), stride)
        if words[start:start + chunk_size_words]
    ]
def build_faiss_index(vectors: np.ndarray) -> faiss.Index:
    """Build a cosine-similarity FAISS index over the row vectors.

    Inner product over L2-normalized rows is equivalent to cosine
    similarity, hence the IndexFlatIP + normalize combination.
    """
    data = vectors.astype("float32")
    # In-place normalization is safe here: astype() above returned a copy,
    # so the caller's array is untouched.
    faiss.normalize_L2(data)
    index = faiss.IndexFlatIP(data.shape[1])
    index.add(data)
    return index
def retrieve_top_k(
    query: str,
    chunks: List[str],
    index: faiss.Index,
    k: int = 3
) -> List[Tuple[int, float, str]]:
    """Return up to *k* hits for *query* as (chunk_id, score, chunk_text),
    best first, searched against the given FAISS index."""
    query_vec = EMBEDDER.encode([query], convert_to_numpy=True).astype("float32")
    # Query must be normalized the same way as the indexed vectors
    # so inner product equals cosine similarity.
    faiss.normalize_L2(query_vec)
    scores, ids = index.search(query_vec, k)
    hits: List[Tuple[int, float, str]] = []
    for score, chunk_id in zip(scores[0], ids[0]):
        # FAISS pads results with -1 when fewer than k vectors exist.
        if chunk_id == -1:
            continue
        pos = int(chunk_id)
        hits.append((pos, float(score), chunks[pos]))
    return hits
def _build_qwen_prompt(question: str, context: str) -> str:
    """Render the chat-template prompt that instructs Qwen to answer
    strictly from the retrieved context."""
    system_msg = (
        "You are a medical QA assistant. "
        "Answer using ONLY the provided context. "
        "If the answer is not present in the context, say exactly: "
        "'Not found in the provided documents.'"
    )
    user_msg = f"Context:\n{context}\n\nQuestion:\n{question}"
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
    # Return the raw prompt string; the generation marker is appended so
    # the model continues in the assistant role.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
def generate_answer(question: str, context: str) -> str:
    """Generate an answer to *question* grounded in *context* via the
    Qwen text-generation pipeline."""
    completions = GENERATOR(
        _build_qwen_prompt(question, context),
        max_new_tokens=256,
        temperature=0.2,  # low temperature keeps answers close to the context
        do_sample=True,
        return_full_text=False,  # drop the echoed prompt, keep only the completion
    )
    answer = completions[0]["generated_text"]
    return answer.strip()
def create_session(chunks: List[str], docs: List[Dict[str, Any]]) -> str:
    """Embed *chunks*, build a FAISS index, and register a new session.

    Args:
        chunks: Non-empty list of text chunks to index.
        docs: List of {"doc_id": int, "filename": str, "num_chunks": int}.

    Returns:
        The new session id (a UUID4 string), usable as a key into SESSIONS.

    Raises:
        ValueError: if *chunks* is empty — an empty embedding matrix would
            otherwise crash FAISS index construction with an opaque error.
    """
    if not chunks:
        raise ValueError("create_session requires at least one text chunk")
    embeddings = EMBEDDER.encode(chunks, convert_to_numpy=True)
    index = build_faiss_index(embeddings)
    session_id = str(uuid.uuid4())
    SESSIONS[session_id] = {
        "chunks": chunks,
        "index": index,
        "created_at": time.time(),
        "docs": docs,
    }
    return session_id