MohitGupta41 commited on
Commit
4d922fd
·
1 Parent(s): 4575791

FastAPI RAG backend (Docker)

Browse files
Files changed (2) hide show
  1. rag.py +219 -23
  2. requirements.txt +1 -0
rag.py CHANGED
@@ -1,7 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import uuid
2
- from dataclasses import dataclass
3
- from typing import List, Dict, Any, Tuple
4
  import time
 
5
 
6
  import numpy as np
7
  import faiss
@@ -10,85 +126,165 @@ from sentence_transformers import SentenceTransformer
10
  # PDF extraction
11
  import fitz # pymupdf
12
 
13
- # LLM (choose 1)
14
- from transformers import pipeline
 
15
 
16
 
17
  # -----------------------------
18
  # Globals (MVP)
19
  # -----------------------------
 
20
  EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")
21
 
22
- # For MVP: use a smallish instruct model if possible
23
- # NOTE: Mistral 7B is heavy; if you can't run it locally, use a smaller HF model.
 
 
 
 
 
 
 
 
 
 
 
24
  GENERATOR = pipeline(
25
- "text2text-generation",
26
- model="google/flan-t5-base",
27
- max_new_tokens=256
28
  )
29
 
30
- SESSIONS: Dict[str, Dict[str, Any]] = {} # session_id -> {chunks, index, created_at}
 
31
 
32
 
33
  # -----------------------------
34
  # Helpers
35
  # -----------------------------
36
  def extract_text_from_pdf(pdf_bytes: bytes) -> str:
 
 
 
 
37
  doc = fitz.open(stream=pdf_bytes, filetype="pdf")
38
  pages = []
39
  for page in doc:
40
  pages.append(page.get_text("text"))
41
  text = "\n".join(pages).strip()
42
- print(text)
43
  return text
44
 
45
 
46
-
47
  def chunk_text(text: str, chunk_size_words: int = 350, overlap_words: int = 60) -> List[str]:
 
 
 
 
 
 
 
 
48
  words = text.split()
49
- chunks = []
50
  step = max(1, chunk_size_words - overlap_words)
 
51
  for i in range(0, len(words), step):
52
  chunk = words[i:i + chunk_size_words]
53
  if chunk:
54
  chunks.append(" ".join(chunk))
55
- return chunks
56
 
 
57
 
58
 
59
  def build_faiss_index(vectors: np.ndarray) -> faiss.Index:
 
 
 
 
60
  vectors = vectors.astype("float32")
61
  dim = vectors.shape[1]
62
- index = faiss.IndexFlatIP(dim) # cosine-like if vectors normalized
 
 
63
  faiss.normalize_L2(vectors)
64
  index.add(vectors)
 
65
  return index
66
 
67
 
68
- def retrieve_top_k(query: str, chunks: List[str], index: faiss.Index, k: int = 3) -> List[Tuple[int, float, str]]:
 
 
 
 
 
 
 
 
69
  q = EMBEDDER.encode([query], convert_to_numpy=True).astype("float32")
70
  faiss.normalize_L2(q)
 
71
  scores, ids = index.search(q, k)
72
- results = []
 
73
  for rank, idx in enumerate(ids[0]):
74
  if idx == -1:
75
  continue
76
  results.append((int(idx), float(scores[0][rank]), chunks[int(idx)]))
 
77
  return results
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def generate_answer(question: str, context: str) -> str:
81
- prompt = (
82
- "Answer using ONLY the provided context. "
83
- "If not found in the context, say: Not found in the provided documents.\n\n"
84
- f"Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
 
 
 
 
 
 
 
85
  )
86
 
87
- out = GENERATOR(prompt)
88
  return out[0]["generated_text"].strip()
89
 
90
 
91
  def create_session(chunks: List[str]) -> str:
 
 
 
92
  embeddings = EMBEDDER.encode(chunks, convert_to_numpy=True)
93
  index = build_faiss_index(embeddings)
94
 
@@ -96,6 +292,6 @@ def create_session(chunks: List[str]) -> str:
96
  SESSIONS[session_id] = {
97
  "chunks": chunks,
98
  "index": index,
99
- "created_at": time.time()
100
  }
101
  return session_id
 
1
+ # import uuid
2
+ # from dataclasses import dataclass
3
+ # from typing import List, Dict, Any, Tuple
4
+ # import time
5
+
6
+ # import numpy as np
7
+ # import faiss
8
+ # from sentence_transformers import SentenceTransformer
9
+
10
+ # # PDF extraction
11
+ # import fitz # pymupdf
12
+
13
+ # # LLM (choose 1)
14
+ # from transformers import pipeline
15
+
16
+
17
+ # # -----------------------------
18
+ # # Globals (MVP)
19
+ # # -----------------------------
20
+ # EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")
21
+
22
+ # # For MVP: use a smallish instruct model if possible
23
+ # # NOTE: Mistral 7B is heavy; if you can't run it locally, use a smaller HF model.
24
+ # GENERATOR = pipeline(
25
+ # "text2text-generation",
26
+ # model="google/flan-t5-base",
27
+ # max_new_tokens=256
28
+ # )
29
+
30
+ # SESSIONS: Dict[str, Dict[str, Any]] = {} # session_id -> {chunks, index, created_at}
31
+
32
+
33
+ # # -----------------------------
34
+ # # Helpers
35
+ # # -----------------------------
36
+ # def extract_text_from_pdf(pdf_bytes: bytes) -> str:
37
+ # doc = fitz.open(stream=pdf_bytes, filetype="pdf")
38
+ # pages = []
39
+ # for page in doc:
40
+ # pages.append(page.get_text("text"))
41
+ # text = "\n".join(pages).strip()
42
+ # print(text)
43
+ # return text
44
+
45
+
46
+
47
+ # def chunk_text(text: str, chunk_size_words: int = 350, overlap_words: int = 60) -> List[str]:
48
+ # words = text.split()
49
+ # chunks = []
50
+ # step = max(1, chunk_size_words - overlap_words)
51
+ # for i in range(0, len(words), step):
52
+ # chunk = words[i:i + chunk_size_words]
53
+ # if chunk:
54
+ # chunks.append(" ".join(chunk))
55
+ # return chunks
56
+
57
+
58
+
59
+ # def build_faiss_index(vectors: np.ndarray) -> faiss.Index:
60
+ # vectors = vectors.astype("float32")
61
+ # dim = vectors.shape[1]
62
+ # index = faiss.IndexFlatIP(dim) # cosine-like if vectors normalized
63
+ # faiss.normalize_L2(vectors)
64
+ # index.add(vectors)
65
+ # return index
66
+
67
+
68
+ # def retrieve_top_k(query: str, chunks: List[str], index: faiss.Index, k: int = 3) -> List[Tuple[int, float, str]]:
69
+ # q = EMBEDDER.encode([query], convert_to_numpy=True).astype("float32")
70
+ # faiss.normalize_L2(q)
71
+ # scores, ids = index.search(q, k)
72
+ # results = []
73
+ # for rank, idx in enumerate(ids[0]):
74
+ # if idx == -1:
75
+ # continue
76
+ # results.append((int(idx), float(scores[0][rank]), chunks[int(idx)]))
77
+ # return results
78
+
79
+
80
+ # def generate_answer(question: str, context: str) -> str:
81
+ # prompt = (
82
+ # "Answer using ONLY the provided context. "
83
+ # "If not found in the context, say: Not found in the provided documents.\n\n"
84
+ # f"Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
85
+ # )
86
+
87
+ # out = GENERATOR(prompt)
88
+ # return out[0]["generated_text"].strip()
89
+
90
+
91
+ # def create_session(chunks: List[str]) -> str:
92
+ # embeddings = EMBEDDER.encode(chunks, convert_to_numpy=True)
93
+ # index = build_faiss_index(embeddings)
94
+
95
+ # session_id = str(uuid.uuid4())
96
+ # SESSIONS[session_id] = {
97
+ # "chunks": chunks,
98
+ # "index": index,
99
+ # "created_at": time.time()
100
+ # }
101
+ # return session_id
102
+
103
+
104
+
105
+
106
+
107
+
108
+
109
+
110
+
111
+
112
+
113
+
114
+
115
+
116
+
117
+ # rag.py
118
# Standard library
import uuid
import time
from typing import List, Dict, Any, Tuple

# Numerics + vector search
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# PDF extraction
import fitz  # pymupdf

# LLM (Qwen)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


# -----------------------------
# Globals (MVP)
# -----------------------------
# Embeddings model (fast + solid baseline)
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")

# Qwen Instruct model (better than flan-t5-base)
QWEN_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"

# Load tokenizer + model
# NOTE(review): both models load at import time — first import downloads
# weights and can take minutes / several GB of RAM; confirm this is
# acceptable for the deployment target.
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    # fp16 on GPU to halve memory; fp32 on CPU where fp16 is slow/unsupported
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",  # uses GPU if available; otherwise CPU
)

# Text-generation pipeline for CausalLMs
GENERATOR = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# In-memory session store: session_id -> {chunks, index, created_at}
# NOTE(review): process-local and unbounded — entries are never evicted.
SESSIONS: Dict[str, Dict[str, Any]] = {}


# -----------------------------
# Helpers
# -----------------------------
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """
    Extract plain text from a PDF using PyMuPDF.

    Args:
        pdf_bytes: raw bytes of the PDF file.

    Returns:
        All page texts joined with newlines, stripped of surrounding
        whitespace (empty string if no text could be extracted).

    Note: For scanned/image PDFs, you'll need OCR (out of scope for MVP).
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    # Fix: close the document explicitly — the original leaked the fitz
    # handle (and its buffers) on every call.
    try:
        pages = [page.get_text("text") for page in doc]
    finally:
        doc.close()
    return "\n".join(pages).strip()
177
 
178
 
 
def chunk_text(text: str, chunk_size_words: int = 350, overlap_words: int = 60) -> List[str]:
    """
    Split *text* into overlapping word-window chunks.

    Each chunk holds up to `chunk_size_words` words; consecutive windows
    advance by stride = chunk_size_words - overlap_words (clamped to at
    least 1), so neighbouring chunks share `overlap_words` words.
    """
    words = text.split()
    stride = max(1, chunk_size_words - overlap_words)

    windows = (
        words[start:start + chunk_size_words]
        for start in range(0, len(words), stride)
    )
    return [" ".join(window) for window in windows if window]
198
 
199
 
def build_faiss_index(vectors: np.ndarray) -> faiss.Index:
    """
    Build a flat inner-product FAISS index over *vectors*.

    Vectors are L2-normalized before insertion, so inner-product search
    behaves like cosine similarity.
    """
    # astype always copies, so the in-place normalization below never
    # mutates the caller's array.
    data = vectors.astype("float32")

    index = faiss.IndexFlatIP(data.shape[1])
    faiss.normalize_L2(data)
    index.add(data)

    return index
214
 
215
 
def retrieve_top_k(
    query: str,
    chunks: List[str],
    index: faiss.Index,
    k: int = 3
) -> List[Tuple[int, float, str]]:
    """
    Embed *query*, search the FAISS index, and return the top-k hits as
    (chunk_id, score, chunk_text) tuples, best first.

    FAISS pads short result sets with id -1; those slots are skipped.
    """
    query_vec = EMBEDDER.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(query_vec)

    scores, ids = index.search(query_vec, k)

    hits: List[Tuple[int, float, str]] = []
    for score, chunk_id in zip(scores[0], ids[0]):
        if chunk_id == -1:  # fewer than k vectors in the index
            continue
        pos = int(chunk_id)
        hits.append((pos, float(score), chunks[pos]))

    return hits
237
 
238
 
def _build_qwen_prompt(question: str, context: str) -> str:
    """
    Render a Qwen chat prompt (system + user turn) through the tokenizer's
    chat template, ready to feed to the text-generation pipeline.
    """
    system_msg = (
        "You are a medical QA assistant. "
        "Answer using ONLY the provided context. "
        "If the answer is not present in the context, say exactly: "
        "'Not found in the provided documents.'"
    )
    user_msg = f"Context:\n{context}\n\nQuestion:\n{question}"

    return tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
265
+
266
+
def generate_answer(question: str, context: str) -> str:
    """
    Generate an answer grounded strictly in the retrieved *context* using
    the Qwen Instruct pipeline.

    Low temperature keeps generations close to the source text;
    return_full_text=False strips the prompt from the pipeline output.
    """
    completion = GENERATOR(
        _build_qwen_prompt(question, context),
        max_new_tokens=256,
        temperature=0.2,
        do_sample=True,
        return_full_text=False
    )

    return completion[0]["generated_text"].strip()
282
 
283
 
def create_session(chunks: List[str]) -> str:
    """
    Create a retrieval session: embed *chunks*, build a FAISS index over
    them, and register both under a fresh UUID in the module-level
    SESSIONS store.

    Returns the new session id.
    """
    vectors = EMBEDDER.encode(chunks, convert_to_numpy=True)

    session_id = str(uuid.uuid4())
    SESSIONS[session_id] = {
        "chunks": chunks,
        "index": build_faiss_index(vectors),
        "created_at": time.time(),
    }
    return session_id
requirements.txt CHANGED
@@ -7,3 +7,4 @@ faiss-cpu
7
  pymupdf
8
  transformers
9
  torch
 
 
7
  pymupdf
8
  transformers
9
  torch
10
+ accelerate