swathibp commited on
Commit
36de92b
·
verified ·
1 Parent(s): ae6e484

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +295 -0
app.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import faiss
3
+ import pickle
4
+ import torch
5
+
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import AutoTokenizer
8
+ from transformers import AutoModelForSeq2SeqLM
9
+
10
+
11
+ import gradio as gr
12
+
13
+
14
+ # ==================================================
15
+ # CONFIG
16
+ # ==================================================
17
+
18
+ CONFIG = {
19
+
20
+ "retriever_model_path":
21
+ "swathibp/BGE-base_finetuned",
22
+
23
+ "generator_model_path":
24
+ "swathibp/Flan_T5_merged",
25
+
26
+ "save_dir":
27
+ ".",
28
+
29
+ "top_k": 3,
30
+
31
+ "max_new_tokens": 250,
32
+
33
+ "device":
34
+ "cuda"
35
+ if torch.cuda.is_available()
36
+ else "cpu"
37
+
38
+ }
39
+
40
+ os.makedirs(
41
+ CONFIG["save_dir"],
42
+ exist_ok=True
43
+ )
44
+
45
+ print(
46
+ "DEVICE:",
47
+ CONFIG["device"]
48
+ )
49
+
50
+
51
+ # ==================================================
52
+ # BUILD / LOAD FAISS
53
+ # ==================================================
54
+
55
+ INDEX_FILE = \
56
+ f"{CONFIG['save_dir']}/index.faiss"
57
+
58
+ DOC_FILE = \
59
+ f"{CONFIG['save_dir']}/docs.pkl"
60
+
61
+
62
+ print("Loading Retriever...")
63
+
64
+ retriever = SentenceTransformer(
65
+ CONFIG["retriever_model_path"]
66
+ )
67
+
68
+ if os.path.exists(INDEX_FILE):
69
+
70
+ print("Loading Stored FAISS Index")
71
+
72
+ index = faiss.read_index(
73
+ INDEX_FILE
74
+ )
75
+
76
+ with open(
77
+ DOC_FILE,
78
+ "rb"
79
+ ) as f:
80
+
81
+ documents = pickle.load(f)
82
+
83
+
84
+ # ==================================================
85
+ # LOAD GENERATOR
86
+ # ==================================================
87
+
88
+ print("Loading FLAN Generator...")
89
+
90
+ tokenizer = \
91
+ AutoTokenizer.from_pretrained(
92
+
93
+ CONFIG[
94
+ "generator_model_path"
95
+ ]
96
+
97
+ )
98
+
99
+ generator = \
100
+ AutoModelForSeq2SeqLM.from_pretrained(
101
+
102
+ CONFIG[
103
+ "generator_model_path"
104
+ ]
105
+
106
+ ).to(
107
+ CONFIG["device"]
108
+ )
109
+
110
+ generator.eval()
111
+
112
+ print("Generator Loaded")
113
+
114
+
115
+ # ==================================================
116
+ # RETRIEVAL
117
+ # ==================================================
118
+
119
+ def retrieve(query):
120
+
121
+ emb = \
122
+ retriever.encode(
123
+
124
+ [query],
125
+
126
+ convert_to_numpy=True
127
+
128
+ )
129
+
130
+ faiss.normalize_L2(
131
+ emb
132
+ )
133
+
134
+ scores, indices = \
135
+ index.search(
136
+
137
+ emb,
138
+
139
+ CONFIG["top_k"]
140
+
141
+ )
142
+
143
+ docs = []
144
+
145
+ for idx in indices[0]:
146
+
147
+ docs.append(
148
+ documents[idx]
149
+ )
150
+
151
+ return docs
152
+
153
+
154
+ # ==================================================
155
+ # GENERATION
156
+ # ==================================================
157
+
158
+ def generate(query):
159
+
160
+ docs = retrieve(query)
161
+
162
+ instruction = (
163
+
164
+ "Answer ONLY using the information provided in the context. "
165
+
166
+ "If the answer is not available, reply exactly: "
167
+
168
+ "'Not found in the provided documents.'"
169
+
170
+ )
171
+
172
+ context = "\n".join(
173
+ docs
174
+ )
175
+
176
+ prompt = f"""
177
+
178
+ {instruction}
179
+
180
+ Context:
181
+
182
+ {context}
183
+
184
+ Question:
185
+
186
+ {query}
187
+
188
+ Answer:
189
+
190
+ """
191
+
192
+ inputs = tokenizer(
193
+
194
+ prompt,
195
+
196
+ return_tensors="pt",
197
+
198
+ truncation=True
199
+
200
+ ).to(
201
+ CONFIG["device"]
202
+ )
203
+
204
+ with torch.no_grad():
205
+
206
+ outputs = \
207
+ generator.generate(
208
+
209
+ **inputs,
210
+
211
+ max_new_tokens=
212
+ CONFIG[
213
+ "max_new_tokens"
214
+ ],
215
+
216
+ do_sample=False,
217
+
218
+ early_stopping=True
219
+
220
+ )
221
+
222
+ answer = \
223
+ tokenizer.decode(
224
+
225
+ outputs[0],
226
+
227
+ skip_special_tokens=True
228
+
229
+ )
230
+
231
+ return answer, context
232
+
233
+
234
+ # ==================================================
235
+ # UI
236
+ # ==================================================
237
+
238
+ with gr.Blocks() as demo:
239
+
240
+ gr.Markdown(
241
+ "# MAHE QA System"
242
+ )
243
+
244
+ q = gr.Textbox(
245
+
246
+ label="Question",
247
+
248
+ placeholder=
249
+ "Enter your MAHE question here...",
250
+
251
+ lines=3,
252
+
253
+ max_lines=5
254
+ )
255
+
256
+ ask = gr.Button(
257
+ "Generate Answer"
258
+ )
259
+
260
+ ans = gr.Textbox(
261
+
262
+ label="Answer",
263
+
264
+ lines=15,
265
+
266
+ max_lines=30,
267
+
268
+ #show_copy_button=True
269
+ )
270
+
271
+ ctx = gr.Textbox(
272
+
273
+ label="Retrieved Context",
274
+
275
+ lines=20,
276
+
277
+ max_lines=40,
278
+
279
+ #show_copy_button=True
280
+ )
281
+
282
+ ask.click(
283
+
284
+ generate,
285
+
286
+ q,
287
+
288
+ [ans, ctx]
289
+
290
+ )
291
+
292
+ demo.launch(
293
+ share=True,
294
+ debug=True
295
+ )