swathibp commited on
Commit
ae6e484
·
verified ·
1 Parent(s): 80729b3

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -295
app.py DELETED
@@ -1,295 +0,0 @@
1
- import os
2
- import faiss
3
- import pickle
4
- import torch
5
-
6
- from sentence_transformers import SentenceTransformer
7
- from transformers import AutoTokenizer
8
- from transformers import AutoModelForSeq2SeqLM
9
-
10
-
11
- import gradio as gr
12
-
13
-
14
- # ==================================================
15
- # CONFIG
16
- # ==================================================
17
-
18
- CONFIG = {
19
-
20
- "retriever_model_path":
21
- "swathibp/BGE-base_finetuned",
22
-
23
- "generator_model_path":
24
- "swathibp/Flan_T5_merged",
25
-
26
- "save_dir":
27
- ".",
28
-
29
- "top_k": 3,
30
-
31
- "max_new_tokens": 250,
32
-
33
- "device":
34
- "cuda"
35
- if torch.cuda.is_available()
36
- else "cpu"
37
-
38
- }
39
-
40
- os.makedirs(
41
- CONFIG["save_dir"],
42
- exist_ok=True
43
- )
44
-
45
- print(
46
- "DEVICE:",
47
- CONFIG["device"]
48
- )
49
-
50
-
51
- # ==================================================
52
- # BUILD / LOAD FAISS
53
- # ==================================================
54
-
55
- INDEX_FILE = \
56
- f"{CONFIG['save_dir']}/index.faiss"
57
-
58
- DOC_FILE = \
59
- f"{CONFIG['save_dir']}/docs.pkl"
60
-
61
-
62
- print("Loading Retriever...")
63
-
64
- retriever = SentenceTransformer(
65
- CONFIG["retriever_model_path"]
66
- )
67
-
68
- if os.path.exists(INDEX_FILE):
69
-
70
- print("Loading Stored FAISS Index")
71
-
72
- index = faiss.read_index(
73
- INDEX_FILE
74
- )
75
-
76
- with open(
77
- DOC_FILE,
78
- "rb"
79
- ) as f:
80
-
81
- documents = pickle.load(f)
82
-
83
-
84
- # ==================================================
85
- # LOAD GENERATOR
86
- # ==================================================
87
-
88
- print("Loading FLAN Generator...")
89
-
90
- tokenizer = \
91
- AutoTokenizer.from_pretrained(
92
-
93
- CONFIG[
94
- "generator_model_path"
95
- ]
96
-
97
- )
98
-
99
- generator = \
100
- AutoModelForSeq2SeqLM.from_pretrained(
101
-
102
- CONFIG[
103
- "generator_model_path"
104
- ]
105
-
106
- ).to(
107
- CONFIG["device"]
108
- )
109
-
110
- generator.eval()
111
-
112
- print("Generator Loaded")
113
-
114
-
115
- # ==================================================
116
- # RETRIEVAL
117
- # ==================================================
118
-
119
- def retrieve(query):
120
-
121
- emb = \
122
- retriever.encode(
123
-
124
- [query],
125
-
126
- convert_to_numpy=True
127
-
128
- )
129
-
130
- faiss.normalize_L2(
131
- emb
132
- )
133
-
134
- scores, indices = \
135
- index.search(
136
-
137
- emb,
138
-
139
- CONFIG["top_k"]
140
-
141
- )
142
-
143
- docs = []
144
-
145
- for idx in indices[0]:
146
-
147
- docs.append(
148
- documents[idx]
149
- )
150
-
151
- return docs
152
-
153
-
154
- # ==================================================
155
- # GENERATION
156
- # ==================================================
157
-
158
- def generate(query):
159
-
160
- docs = retrieve(query)
161
-
162
- instruction = (
163
-
164
- "Answer ONLY using the information provided in the context. "
165
-
166
- "If the answer is not available, reply exactly: "
167
-
168
- "'Not found in the provided documents.'"
169
-
170
- )
171
-
172
- context = "\n".join(
173
- docs
174
- )
175
-
176
- prompt = f"""
177
-
178
- {instruction}
179
-
180
- Context:
181
-
182
- {context}
183
-
184
- Question:
185
-
186
- {query}
187
-
188
- Answer:
189
-
190
- """
191
-
192
- inputs = tokenizer(
193
-
194
- prompt,
195
-
196
- return_tensors="pt",
197
-
198
- truncation=True
199
-
200
- ).to(
201
- CONFIG["device"]
202
- )
203
-
204
- with torch.no_grad():
205
-
206
- outputs = \
207
- generator.generate(
208
-
209
- **inputs,
210
-
211
- max_new_tokens=
212
- CONFIG[
213
- "max_new_tokens"
214
- ],
215
-
216
- do_sample=False,
217
-
218
- early_stopping=True
219
-
220
- )
221
-
222
- answer = \
223
- tokenizer.decode(
224
-
225
- outputs[0],
226
-
227
- skip_special_tokens=True
228
-
229
- )
230
-
231
- return answer, context
232
-
233
-
234
- # ==================================================
235
- # UI
236
- # ==================================================
237
-
238
- with gr.Blocks() as demo:
239
-
240
- gr.Markdown(
241
- "# MAHE QA System"
242
- )
243
-
244
- q = gr.Textbox(
245
-
246
- label="Question",
247
-
248
- placeholder=
249
- "Enter your MAHE question here...",
250
-
251
- lines=3,
252
-
253
- max_lines=5
254
- )
255
-
256
- ask = gr.Button(
257
- "Generate Answer"
258
- )
259
-
260
- ans = gr.Textbox(
261
-
262
- label="Answer",
263
-
264
- lines=15,
265
-
266
- max_lines=30,
267
-
268
- #show_copy_button=True
269
- )
270
-
271
- ctx = gr.Textbox(
272
-
273
- label="Retrieved Context",
274
-
275
- lines=20,
276
-
277
- max_lines=40,
278
-
279
- show_copy_button=True
280
- )
281
-
282
- ask.click(
283
-
284
- generate,
285
-
286
- q,
287
-
288
- [ans, ctx]
289
-
290
- )
291
-
292
- demo.launch(
293
- share=True,
294
- debug=True
295
- )