# rag.py
import uuid
import time
from typing import List, Dict, Any, Tuple
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
# PDF extraction
import fitz # pymupdf
# LLM (Qwen)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# -----------------------------
# Globals (MVP)
# -----------------------------
# Sentence-embedding model shared by indexing and querying — retrieval
# only works because chunks and queries go through the same encoder.
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")
QWEN_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
# LLM loaded eagerly at import time: fp16 when a CUDA device is present,
# fp32 on CPU; device_map="auto" lets accelerate place the weights.
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
# Text-generation pipeline wrapping the model/tokenizer pair above.
GENERATOR = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
# In-memory session store (MVP — no eviction, lost on restart):
# session_id -> {chunks, index, created_at, docs}
SESSIONS: Dict[str, Dict[str, Any]] = {}
# -----------------------------
# Helpers
# -----------------------------
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Extract plain text from an in-memory PDF.

    Args:
        pdf_bytes: Raw bytes of a PDF file.

    Returns:
        The text of all pages joined with newlines, stripped of leading
        and trailing whitespace.
    """
    # Context manager closes the underlying MuPDF document even if a page
    # raises; the previous version leaked the document handle.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        pages = [page.get_text("text") for page in doc]
    return "\n".join(pages).strip()
def chunk_text(text: str, chunk_size_words: int = 350, overlap_words: int = 60) -> List[str]:
    """Split *text* into word-based chunks with overlap between neighbors.

    Args:
        text: Raw document text.
        chunk_size_words: Maximum number of words per chunk.
        overlap_words: Number of words shared between consecutive chunks.

    Returns:
        List of chunk strings; an empty list for empty/whitespace-only input.
    """
    words = text.split()
    chunks: List[str] = []
    # Guard against overlap >= chunk size, which would make the window
    # stand still (or walk backwards).
    step = max(1, chunk_size_words - overlap_words)
    for i in range(0, len(words), step):
        chunks.append(" ".join(words[i:i + chunk_size_words]))
        # Stop once a chunk reaches the end of the text; continuing would
        # emit trailing chunks that are fully contained in this one
        # (pure duplicates that only pollute retrieval).
        if i + chunk_size_words >= len(words):
            break
    return chunks
def build_faiss_index(vectors: np.ndarray) -> faiss.Index:
    """Build an inner-product FAISS index over L2-normalized copies of *vectors*.

    Normalizing before adding makes inner-product search equivalent to
    cosine similarity. The input array is not modified.
    """
    # astype always copies, so normalize_L2 (which works in place) never
    # touches the caller's array.
    mat = vectors.astype("float32")
    faiss.normalize_L2(mat)
    flat_index = faiss.IndexFlatIP(mat.shape[1])
    flat_index.add(mat)
    return flat_index
def retrieve_top_k(
    query: str,
    chunks: List[str],
    index: faiss.Index,
    k: int = 3
) -> List[Tuple[int, float, str]]:
    """Return the top-k chunks most similar to *query*.

    Each result is a ``(chunk_index, score, chunk_text)`` tuple in
    decreasing-score order. Unfilled result slots (index holds fewer than
    k vectors) are skipped.
    """
    query_vec = EMBEDDER.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(query_vec)
    scores, ids = index.search(query_vec, k)
    hits: List[Tuple[int, float, str]] = []
    for chunk_id, score in zip(ids[0], scores[0]):
        if chunk_id == -1:  # FAISS pads missing results with -1
            continue
        pos = int(chunk_id)
        hits.append((pos, float(score), chunks[pos]))
    return hits
def _build_qwen_prompt(question: str, context: str) -> str:
    """Render the grounded-QA chat (system + user turn) into a single
    generation-ready prompt string via the tokenizer's chat template."""
    system_text = (
        "You are a medical QA assistant. "
        "Answer using ONLY the provided context. "
        "If the answer is not present in the context, say exactly: "
        "'Not found in the provided documents.'"
    )
    user_text = f"Context:\n{context}\n\nQuestion:\n{question}"
    chat = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    return tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
def generate_answer(question: str, context: str) -> str:
    """Generate an answer to *question* grounded in *context* with the
    shared Qwen pipeline; returns the stripped completion text only."""
    outputs = GENERATOR(
        _build_qwen_prompt(question, context),
        max_new_tokens=256,
        temperature=0.2,   # low temperature: keep answers close to the context
        do_sample=True,
        return_full_text=False,  # drop the echoed prompt, keep the completion
    )
    return outputs[0]["generated_text"].strip()
def create_session(chunks: List[str], docs: List[Dict[str, Any]]) -> str:
    """Embed *chunks*, build a FAISS index, and register a new session.

    Args:
        chunks: Text chunks to index; must be non-empty.
        docs: list of {"doc_id": int, "filename": str, "num_chunks": int}

    Returns:
        The new session id (UUID4 string) under which chunks, index,
        creation timestamp, and doc metadata are stored in SESSIONS.

    Raises:
        ValueError: If *chunks* is empty — an empty embedding matrix would
            otherwise crash FAISS index construction with an opaque error.
    """
    if not chunks:
        raise ValueError("create_session: 'chunks' must contain at least one chunk")
    embeddings = EMBEDDER.encode(chunks, convert_to_numpy=True)
    index = build_faiss_index(embeddings)
    session_id = str(uuid.uuid4())
    SESSIONS[session_id] = {
        "chunks": chunks,
        "index": index,
        "created_at": time.time(),
        "docs": docs,
    }
    return session_id
|