Oralyz committed on
Commit
b6281eb
·
verified ·
1 Parent(s): 247233b

Upload rag_config.py

Browse files
Files changed (1) hide show
  1. rag_config.py +421 -0
rag_config.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """rag_config.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1IUUxrU5dDjy-Ap_49dbJoGVZaL_ndrf8
8
+ """
9
+
10
+ # Imports
11
+
12
+ # General imports
13
+ import numpy as np
14
+ import re
15
+ import pandas as pd
16
+
17
+ # Pytorch and transformers (for LLM)
18
+ import transformers, torch
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, AutoModel
20
+ transformers.logging.set_verbosity_info()
21
+
22
+ # For loading documents from a path
23
+ from pathlib import Path
24
+
25
+ # For the embedding module
26
+ from sentence_transformers import SentenceTransformer
27
+
28
+ # %%
29
+
30
+ # Load device
31
+
32
# Select the compute device once at import time.
# Preference order: Apple-silicon MPS, then CUDA, and finally plain CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
38
+
39
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
40
+ import torch
41
+
42
class FoundationModel():
    """Causal language model wrapped in a Hugging Face text-generation pipeline.

    Attributes:
        model: the model returned by ``AutoModelForCausalLM.from_pretrained``,
            moved to the module-level ``device``.
        tokenizer: the tokenizer loaded from the same path.
        llm: the ``text-generation`` pipeline built from ``model`` and ``tokenizer``.
        num_parameters: value of ``model.num_parameters()``.
    """

    # Instructions placed in front of the question by generate_response.
    _PLAIN_INSTRUCTIONS = (
        "\n"
        "You are a medical assistant.\n"
        "\n"
        "Use the following context to answer the question.\n"
        "\n"
        "IMPORTANT RULES:\n"
        "- Do not mention the context.\n"
        "- Do not mention figures or sections.\n"
        '- Do not say "according to the context".\n'
        "- Give a clear explanation as if you are speaking to a patient.\n"
    )

    # Instructions placed in front of the retrieved contexts by
    # generate_response_with_context.
    _RAG_INSTRUCTIONS = (
        "\n"
        "\n"
        "\n"
        "You are a dental pathology expert.\n"
        "Answer strictly using the provided context.\n"
        "If the answer is not in the context, say: I don't know.\n"
        "\n"
        "IMPORTANT RULES:\n"
        "- Do not mention the context.\n"
        "- Do not mention figures or sections.\n"
        '- Do not say "according to the context".\n'
        "- Give a clear explanation as if you are speaking to a patient.\n"
    )

    def __init__(self, FOUND_MODEL_PATH, TEMPERATURE=0.7, MAX_NEW_TOKENS=1024):
        """Load the model and tokenizer and build the generation pipeline.

        Args:
            FOUND_MODEL_PATH: model name or path passed to ``from_pretrained``.
            TEMPERATURE: value stored in ``model.generation_config.temperature``.
            MAX_NEW_TOKENS: ``max_new_tokens`` given to the pipeline.
        """
        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the checkpoint; only point FOUND_MODEL_PATH at repositories you trust.
        self.model = AutoModelForCausalLM.from_pretrained(
            FOUND_MODEL_PATH,
            torch_dtype="auto",
            trust_remote_code=True
        ).to(device)

        self.tokenizer = AutoTokenizer.from_pretrained(FOUND_MODEL_PATH)

        # Generation config
        self.model.generation_config.temperature = TEMPERATURE
        self.model.generation_config.top_p = None

        self.llm = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            return_full_text=False,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            max_length=None
        )

        self.num_parameters = self.model.num_parameters()
        print('Number of parameters in my model',
              '{:.2e}'.format(self.num_parameters))

    # Simple prompt (no chat template)
    def generate_response(self, prompt):
        """Answer ``prompt`` without any retrieved context.

        Args:
            prompt: the user question (str).

        Returns:
            The ``"generated_text"`` field of the first pipeline result.
        """
        formatted_prompt = (
            self._PLAIN_INSTRUCTIONS
            + "Question:\n"
            + f"{prompt}\n"
            + "\n"
            + "Answer:\n"
        )

        output = self.llm(formatted_prompt)
        # The text-generation pipeline returns a list of dicts keyed by
        # "generated_text"; the previous empty-string key always raised KeyError.
        return output[0]["generated_text"]

    # RAG version (context aware)
    def generate_response_with_context(self, prompt, context):
        """Answer ``prompt`` using the retrieved ``context`` passages.

        The prompt is laid out as instructions, then the numbered contexts,
        then the question exactly once, so the model sees the passages before
        it is asked to answer.

        Args:
            prompt: the user question (str).
            context: iterable of passages; each item is either a string or a
                dict with a ``"content"`` key (the shape built by
                ``RAG.get_retrieval``). May be empty or ``None``.

        Returns:
            The raw pipeline output (unchanged from the previous behaviour,
            i.e. not reduced to the generated string).
        """
        parts = [self._RAG_INSTRUCTIONS]

        if context:
            for i, ctx in enumerate(context):
                # Feed only the passage text to the model, not the dict repr
                # (which would also leak the source/file name into the prompt).
                if isinstance(ctx, dict) and "content" in ctx:
                    ctx = ctx["content"]
                parts.append(f"\nContext {i+1}:\n{ctx}\n")

        parts.append(f"\nQuestion:\n{prompt}\nAnswer:")
        full_prompt = "".join(parts)

        output = self.llm(full_prompt)

        return output
126
+
127
class EmbeddingModel():
    """Sentence-embedding model with cosine-similarity helpers.

    Attributes:
        Embedmodel: the loaded ``SentenceTransformer``, moved to the
            module-level ``device``.
        dim: embedding dimension reported by the loaded model.
    """

    def __init__(self, EMBEDD_MODEL_PATH):
        """Load the embedding model.

        Args:
            EMBEDD_MODEL_PATH: name of the embedding model as understood by
                the SentenceTransformer library.
        """
        self.Embedmodel = SentenceTransformer(EMBEDD_MODEL_PATH).to(device)
        # Query the model that is already loaded rather than instantiating a
        # second copy just to read its dimension.
        self.dim = self.Embedmodel.get_sentence_embedding_dimension()

    def get_embeddings(self, texts):
        """Embed a list of strings.

        Args:
            texts: list of strings (typically chunk contents, without source).

        Returns:
            The tensor produced by ``encode(..., convert_to_tensor=True)``,
            moved to ``device``; expected shape is ``(len(texts), self.dim)``.
        """
        embeddings = self.Embedmodel.encode(texts, convert_to_tensor=True).to(device)
        return embeddings

    def compute_cos_sim_embed(self, embed1, embed2):
        """Return the cosine similarity of two embeddings as a float.

        Args:
            embed1: tensor holding one embedding (any shape; it is flattened).
            embed2: tensor holding one embedding (any shape; it is flattened).

        Returns:
            ``dot(embed1, embed2) / (||embed1|| * ||embed2||)``, or ``0.0``
            when either vector has zero norm (instead of dividing by zero).
        """
        # reshape() also accepts non-contiguous tensors, unlike view().
        embed1 = embed1.reshape(-1)
        embed2 = embed2.reshape(-1)

        norm1 = torch.norm(embed1, p=2, dim=0).item()
        norm2 = torch.norm(embed2, p=2, dim=0).item()
        denominator = norm1 * norm2

        if denominator == 0:
            return 0.0

        return torch.dot(embed1, embed2).item() / denominator

    def compute_cos_sim_texts(self, text_1, text_2):
        """Return the cosine similarity of two strings as a float.

        Args:
            text_1: first text (str).
            text_2: second text (str).
        """
        embeds = self.get_embeddings(texts=[text_1, text_2])

        return self.compute_cos_sim_embed(embeds[0], embeds[1])
170
+
171
class Chunk():
    """A piece of text together with its origin and its embedding.

    Attributes:
        embedding_model: the embedding model used to embed the content.
        source: where the text comes from, coerced to ``str``.
        content: the text itself, coerced to ``str``.
        embedding: result of ``embedding_model.get_embeddings`` reshaped to
            ``(1, embedding_model.dim)``.
    """

    def __init__(self, source, content, embed_model: "EmbeddingModel"):
        """Store the text and compute its embedding.

        Args:
            source: identifier of the originating document/page.
            content: the chunk text.
            embed_model: object exposing ``dim`` and ``get_embeddings(texts=...)``.
        """
        self.embedding_model = embed_model

        # dim is the common dimension of the embeddings
        dim = self.embedding_model.dim

        self.source = str(source)
        self.content = str(content)
        # Embed the coerced string so the stored text and the embedded text
        # are always the same object, even if a non-str content was passed.
        self.embedding = self.embedding_model.get_embeddings(texts=[self.content]).reshape(1, dim)

    def print_chunk(self):
        """Print the source, the content and the embedding shape."""
        print('source:', self.source, 'content:', self.content, 'embedding shape:', self.embedding.shape)
190
+
191
+
192
+
193
+ ### Splitter only looks for .pdf files --------> change if needed
194
+ from pathlib import Path
195
+ from pypdf import PdfReader
196
class Splitter():
    """Load PDF pages and cut their text into embedded ``Chunk`` objects.

    Attributes:
        embedding_model: embedding model handed to every ``Chunk``.
        docs: list of ``{"source": str, "content_page": str}`` dicts, one per
            non-empty PDF page read by the last ``get_documents`` call.
        chunks: list of ``Chunk`` objects accumulated so far.
    """

    def __init__(self, embed_model: "EmbeddingModel"):
        """Create an empty splitter bound to ``embed_model``."""
        self.embedding_model = embed_model

        # Pages of the loaded documents, as {"source": ..., "content_page": ...}
        self.docs = []
        # This will be the list of chunks
        self.chunks = []

    @staticmethod
    def _read_pdf_pages(pdf_path):
        """Return one doc dict per page of ``pdf_path`` that yields text."""
        reader = PdfReader(pdf_path)
        pages = []
        for page_number, page in enumerate(reader.pages, start=1):
            text = page.extract_text()
            if text:
                pages.append({
                    "source": f"{pdf_path.name}_page_{page_number}",
                    "content_page": text.strip()
                })
        return pages

    @staticmethod
    def _window_spans(length, chunk_size, overlap):
        """Return the ``(start, end)`` slices of a sliding window over ``length`` chars.

        Raises:
            ValueError: if ``chunk_size`` is not positive or the window would
                not advance (``overlap >= chunk_size``).
        """
        if chunk_size <= 0:
            raise ValueError('chunk_size must be a positive integer')
        step = chunk_size - overlap
        if step <= 0:
            # A zero step would never terminate the loop below.
            raise ValueError('Careful overlap must be smaller than chunk_size')

        spans = []
        start = 0
        while start < length:
            end = min(length, start + chunk_size)
            spans.append((start, end))
            if end == length:
                # Any further window would be fully contained in this one.
                break
            start += step
        return spans

    def get_documents(self, path_doc):
        """Read the PDF(s) at ``path_doc`` into ``self.docs``.

        Args:
            path_doc: a directory (searched recursively for ``*.pdf``) or the
                path of a single PDF file; ``str`` or ``Path``.

        ``self.docs`` is replaced, not extended. A path that is neither a
        directory nor a ``.pdf`` file results in an empty list.
        """
        docs = []

        path = Path(path_doc)

        # Directory containing several PDFs (sorted for a reproducible order)
        if path.is_dir():
            for pdf_file in sorted(path.rglob("*.pdf")):
                docs.extend(self._read_pdf_pages(pdf_file))

        # Single PDF file; the suffix test works for str and Path inputs
        elif path.suffix.lower() == ".pdf":
            docs.extend(self._read_pdf_pages(path))

        self.docs = docs

    def get_chunks_contents_from_1_doc(self, file_name, content_page, chunk_size, overlap, sentence_split=False):
        """Cut one page of text into chunks and append them to ``self.chunks``.

        Args:
            file_name: source label stored on every created chunk.
            content_page: the page text (str).
            chunk_size: window length in characters (fixed-size mode).
            overlap: number of characters shared by consecutive windows.
            sentence_split: if true, split on ``"."`` instead and ignore the
                window parameters (empty pieces are skipped).

        Raises:
            ValueError: if ``chunk_size < overlap`` (both modes), or, in
                fixed-size mode, if ``overlap >= chunk_size`` or
                ``chunk_size <= 0``.
        """
        if chunk_size < overlap:
            raise ValueError('Careful overlap must be smaller than chunk_size')

        if sentence_split:
            for sentence in content_page.split("."):
                sentence = sentence.lstrip()
                if sentence:
                    self.chunks.append(Chunk(source=file_name,
                                             content=sentence,
                                             embed_model=self.embedding_model))
            return

        # Fixed-size sliding window according to chunk_size and overlap
        for start, end in self._window_spans(len(content_page), chunk_size, overlap):
            self.chunks.append(Chunk(source=file_name,
                                     content=content_page[start:end],
                                     embed_model=self.embedding_model))

    def get_chunks(self, path_doc, chunk_size, overlap, sentence_split=False):
        """Load the documents at ``path_doc`` and chunk every page.

        New chunks are appended to ``self.chunks``; arguments are forwarded
        to ``get_documents`` and ``get_chunks_contents_from_1_doc``.
        """
        self.get_documents(path_doc=path_doc)

        for doc in self.docs:
            self.get_chunks_contents_from_1_doc(file_name=doc["source"],
                                                content_page=doc["content_page"],
                                                chunk_size=chunk_size,
                                                overlap=overlap,
                                                sentence_split=sentence_split)

    def reset_splitter(self):
        """Forget all loaded documents and chunks."""
        self.docs = []
        self.chunks = []
301
+
302
class Retriever():
    """In-memory index of chunks searched by cosine similarity.

    Attributes:
        embedding_model: model used to embed queries and compare embeddings.
        index: list of ``[chunk_id, chunk]`` pairs; ids are consecutive ints
            starting at 0 in insertion order.
    """

    def __init__(self, embed_model: "EmbeddingModel"):
        """Create an empty index bound to ``embed_model``."""
        self.embedding_model = embed_model

        # The index is a list of [id (int), chunk]
        self.index = []

    def add_elements_to_index(self, chunks):
        """Append ``chunks`` (an iterable of chunk objects) to the index.

        Each chunk receives the next free integer id.
        """
        for chunk_id, chunk in enumerate(chunks, start=len(self.index)):
            self.index.append([chunk_id, chunk])

    def search_best(self, query, number_of_hits=3):
        """Return the chunks most similar to ``query``.

        Args:
            query: the search text (str).
            number_of_hits: maximum number of results; values <= 0 give ``[]``.

        Returns:
            List of ``(chunk_id, chunk, similarity)`` tuples sorted by
            decreasing similarity, at most ``number_of_hits`` long.
        """
        # A negative count must not turn into a "[:-k]" slice.
        limit = max(0, number_of_hits)
        if limit == 0 or not self.index:
            return []

        # get_embeddings already places the tensor where the chunk embeddings
        # live, so no extra device move is needed here.
        query_embed = self.embedding_model.get_embeddings(texts=[query]).reshape(
            1, self.embedding_model.dim)

        scores = [
            (chunk_id,
             chunk,
             self.embedding_model.compute_cos_sim_embed(embed1=query_embed,
                                                        embed2=chunk.embedding))
            for chunk_id, chunk in self.index
        ]

        return sorted(scores, key=lambda hit: hit[2], reverse=True)[:limit]

    def reset_Retriever_index(self):
        """Remove every entry from the index."""
        self.index = []
350
+
351
class RAG():
    """Retrieval-augmented generation pipeline: split, index, retrieve, answer.

    Attributes:
        foundation_model: ``FoundationModel`` built from ``CONFIG['FOUND_MODEL_PATH']``.
        Embedding_model: ``EmbeddingModel`` built from ``CONFIG['EMBEDD_MODEL_PATH']``.
        splitter: ``Splitter`` sharing the embedding model.
        retriever: ``Retriever`` sharing the embedding model.
        dim_embed, chunk_size, overlap: values copied from ``CONFIG``.
    """

    def __init__(self, CONFIG):
        """Build all components from a configuration mapping.

        Args:
            CONFIG: mapping with the keys ``FOUND_MODEL_PATH``,
                ``EMBEDD_MODEL_PATH``, ``DIM_EMBED``, ``CHUNK_SIZE`` and
                ``OVERLAP``.
        """
        self.foundation_model = FoundationModel(FOUND_MODEL_PATH=CONFIG['FOUND_MODEL_PATH'])
        self.Embedding_model = EmbeddingModel(EMBEDD_MODEL_PATH=CONFIG['EMBEDD_MODEL_PATH'])
        self.splitter = Splitter(self.Embedding_model)
        self.retriever = Retriever(self.Embedding_model)

        self.dim_embed = CONFIG['DIM_EMBED']
        self.chunk_size = CONFIG['CHUNK_SIZE']
        self.overlap = CONFIG['OVERLAP']

    def reset_index(self):
        """Empty both the retriever index and the splitter state."""
        self.retriever.reset_Retriever_index()
        self.splitter.reset_splitter()

    def load_documents_and_get_chunks(self, path, sentence_split=False):
        """Chunk the documents at ``path`` and add the new chunks to the index.

        Args:
            path: directory or PDF file handed to the splitter.
            sentence_split: forwarded to the splitter.
        """
        # The splitter accumulates chunks across calls; remember how many it
        # already holds so that earlier chunks are not indexed a second time.
        already_indexed = len(self.splitter.chunks)

        self.splitter.get_chunks(path_doc=path,
                                 chunk_size=self.chunk_size,
                                 overlap=self.overlap,
                                 sentence_split=sentence_split)

        new_chunks = self.splitter.chunks[already_indexed:]

        self.retriever.add_elements_to_index(chunks=new_chunks)

    def get_retrieval(self, query, number_of_hits):
        """Return the best-matching passages for ``query`` without duplicates.

        Args:
            query: the search text (str).
            number_of_hits: maximum number of hits requested from the retriever.

        Returns:
            List of ``{"source": ..., "content": ...}`` dicts in ranking
            order; hits repeating an earlier (source, content) pair are dropped.
        """
        # Each hit has the form (id, chunk, similarity)
        retrieved_info = self.retriever.search_best(query=query, number_of_hits=number_of_hits)

        retrieved = []
        seen = set()

        for _, chunk, _ in retrieved_info:
            # Dicts are unhashable, so deduplicate on a tuple key instead of
            # feeding the dicts themselves to dict.fromkeys.
            key = (chunk.source, chunk.content)
            if key in seen:
                continue
            seen.add(key)
            retrieved.append({
                "source": chunk.source,
                "content": chunk.content
            })

        return retrieved

    def generate_response_with_context(self, query, number_of_hits=3):
        """Retrieve passages for ``query`` and let the foundation model answer.

        Args:
            query: the user question (str).
            number_of_hits: how many passages to retrieve (default 3).

        Returns:
            Whatever ``foundation_model.generate_response_with_context`` returns.
        """
        retrieved = self.get_retrieval(query=query,
                                       number_of_hits=number_of_hits)

        return self.foundation_model.generate_response_with_context(prompt=query,
                                                                    context=retrieved)
411
+
412
# Default settings consumed by RAG(CONFIG).
CONFIG = dict(
    FOUND_MODEL_PATH="mistralai/Mistral-7B-Instruct-v0.2",  # alternative: medicalai/MedFound-7B
    EMBEDD_MODEL_PATH="all-MiniLM-L6-v2",
    DIM_EMBED=384,
    CHUNK_SIZE=300,
    OVERLAP=30,
)
419
+
420
+
421
+