theachyuttiwari committed on
Commit
288e608
·
1 Parent(s): 76f19cc

Upload ask.py

Browse files
Files changed (1) hide show
  1. ask.py +376 -0
ask.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import colorsys
2
+ import json
3
+ import re
4
+ import time
5
+
6
+ import nltk
7
+ import numpy as np
8
+ from nltk import tokenize
9
+
10
+ nltk.download('punkt')
11
+ from google.oauth2 import service_account
12
+ from google.cloud import texttospeech
13
+
14
+ from typing import Dict, Optional, List
15
+
16
+ import jwt
17
+ import requests
18
+ import streamlit as st
19
+ from sentence_transformers import SentenceTransformer, util, CrossEncoder
20
+
21
# Configuration values read from Streamlit's secrets store when the module is imported.
JWT_SECRET = st.secrets["api_secret"]  # signing key passed to jwt.encode in sign_jwt()
JWT_ALGORITHM = st.secrets["api_algorithm"]  # algorithm name passed to jwt.encode
INFERENCE_TOKEN = st.secrets["api_inference"]  # bearer token placed in `headers` below
CONTEXT_API_URL = st.secrets["api_context"]  # URL prefix; the question is appended to it
LFQA_API_URL = st.secrets["api_lfqa"]  # endpoint POSTed to by inference_lfqa()

# Authorization header shared by the two api-inference.huggingface.co requests
# (api_inference_lfqa and hf_tts).
headers = {"Authorization": f"Bearer {INFERENCE_TOKEN}"}
API_URL = "https://api-inference.huggingface.co/models/vblagoje/bart_lfqa"
API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan"
30
+
31
+
32
def api_inference_lfqa(model_input: str):
    """Generate an answer by POSTing ``model_input`` to ``API_URL``.

    Generation parameters (min/max length, sampling, beams, temperature) are
    read from ``st.session_state``.

    Args:
        model_input: The prompt string sent as ``"inputs"``.

    Returns:
        The decoded JSON body of the response. If the request fails at the
        transport level, or the body is not valid UTF-8 JSON, a dict of the
        form ``{"error": "<message>"}`` is returned instead of raising, which
        is the shape ``inference_lfqa`` uses for its failures.
    """
    payload = {
        "inputs": model_input,
        "parameters": {
            "truncation": "longest_first",
            "min_length": st.session_state["min_length"],
            "max_length": st.session_state["max_length"],
            "do_sample": st.session_state["do_sample"],
            "early_stopping": st.session_state["early_stopping"],
            "num_beams": st.session_state["num_beams"],
            "temperature": st.session_state["temperature"],
            "top_k": None,
            "top_p": None,
            "no_repeat_ngram_size": 3,
            "num_return_sequences": 1
        },
        "options": {
            "wait_for_model": True
        }
    }
    data = json.dumps(payload)
    try:
        response = requests.request("POST", API_URL, headers=headers, data=data)
        # The body is parsed regardless of status code, so a JSON error
        # document sent by the service is passed through to the caller.
        return json.loads(response.content.decode("utf-8"))
    except requests.exceptions.RequestException as e:
        return {"error": str(e)}
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return {"error": f"LFQA service returned an unreadable response: {e}"}
55
+
56
+
57
def inference_lfqa(model_input: str, header: dict):
    """Generate an answer by POSTing ``model_input`` to ``LFQA_API_URL``.

    Generation parameters are read from ``st.session_state``.

    Args:
        model_input: The prompt string sent as ``"model_input"``.
        header: HTTP headers for the request (the caller supplies the
            Authorization header).

    Returns:
        The decoded JSON body on HTTP 200; otherwise a dict of the form
        ``{"error": "<message>"}``. The error value is always a string so it
        can be displayed or serialised directly.
    """
    payload = {
        "model_input": model_input,
        "parameters": {
            "min_length": st.session_state["min_length"],
            "max_length": st.session_state["max_length"],
            "do_sample": st.session_state["do_sample"],
            "early_stopping": st.session_state["early_stopping"],
            "num_beams": st.session_state["num_beams"],
            "temperature": st.session_state["temperature"],
            "top_k": None,
            "top_p": None,
            "no_repeat_ngram_size": 3,
            "num_return_sequences": 1
        }
    }
    data = json.dumps(payload)
    try:
        response = requests.request("POST", LFQA_API_URL, headers=header, data=data)
        if response.status_code == 200:
            result = json.loads(response.content.decode("utf-8"))
        else:
            result = {"error": f"LFQA service unavailable, status code={response.status_code}"}
    except requests.exceptions.RequestException as e:
        # Store the message, not the exception object, to match the other branch.
        result = {"error": str(e)}
    except ValueError as e:
        # A 200 response whose body is not valid UTF-8 JSON.
        result = {"error": f"LFQA service returned an unreadable response: {e}"}
    return result
84
+
85
+
86
def invoke_lfqa(service_backend: str, model_input: str, header: Optional[dict]):
    """Dispatch answer generation to the selected backend.

    ``"HuggingFace"`` routes to ``api_inference_lfqa``; any other value routes
    to ``inference_lfqa`` together with ``header``.
    """
    if service_backend == "HuggingFace":
        return api_inference_lfqa(model_input)
    return inference_lfqa(model_input, header)
92
+
93
+
94
@st.cache(allow_output_mutation=True, show_spinner=False)
def hf_tts(text: str):
    """POST ``text`` to ``API_URL_TTS`` and return the raw response body.

    The result is memoised by ``st.cache``.
    """
    synthesis_parameters = {
        "vocoder_tag": "str_or_none(none)",
        "threshold": 0.5,
        "minlenratio": 0.0,
        "maxlenratio": 10.0,
        "use_att_constraint": False,
        "backward_window": 1,
        "forward_window": 3,
        "speed_control_alpha": 1.0,
        "noise_scale": 0.333,
        "noise_scale_dur": 0.333
    }
    request_body = json.dumps({
        "inputs": text,
        "parameters": synthesis_parameters,
        "options": {"wait_for_model": True},
    })
    tts_response = requests.request("POST", API_URL_TTS, headers=headers, data=request_body)
    return tts_response.content
117
+
118
+
119
@st.cache(allow_output_mutation=True, show_spinner=False)
def google_tts(text: str, private_key_id: str, private_key: str, client_email: str):
    """Synthesize ``text`` with Google Cloud Text-to-Speech.

    Service-account credentials are assembled from the three key arguments;
    ``private_key`` is wrapped in the PEM BEGIN/END markers here. The request
    asks for an "en-US" voice with neutral SSML gender and MP3 encoding.

    Returns:
        The object returned by ``client.synthesize_speech`` (memoised by
        ``st.cache``).
    """
    service_account_info = {
        "private_key_id": private_key_id,
        "private_key": f"-----BEGIN PRIVATE KEY-----\n{private_key}\n-----END PRIVATE KEY-----\n",
        "client_email": client_email,
        "token_uri": "https://oauth2.googleapis.com/token",
    }
    tts_client = texttospeech.TextToSpeechClient(
        credentials=service_account.Credentials.from_service_account_info(service_account_info)
    )

    return tts_client.synthesize_speech(
        input=texttospeech.SynthesisInput(text=text),
        # language code "en-US", SSML voice gender "neutral"
        voice=texttospeech.VoiceSelectionParams(
            language_code="en-US",
            ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL,
        ),
        # type of audio returned
        audio_config=texttospeech.AudioConfig(audio_encoding=texttospeech.AudioEncoding.MP3),
    )
144
+
145
+
146
def request_context_passages(question, header):
    """Fetch context passages for ``question`` from ``CONTEXT_API_URL``.

    Args:
        question: The user's question. It is percent-encoded before being
            appended to the URL, so characters such as ``#``, ``&``, ``?`` or
            ``/`` cannot alter the structure of the request URL.
        header: HTTP headers for the request.

    Returns:
        The decoded JSON body on HTTP 200; otherwise a dict of the form
        ``{"error": "<message>"}`` whose value is always a string.
    """
    from urllib.parse import quote  # local import keeps this block self-contained

    try:
        url = CONTEXT_API_URL + quote(question, safe="")
        response = requests.request("GET", url, headers=header)
        if response.status_code == 200:
            result = json.loads(response.content.decode("utf-8"))
        else:
            result = {"error": f"Context passage service unavailable, status code={response.status_code}"}
    except requests.exceptions.RequestException as e:
        # Store the message, not the exception object, to match the other branch.
        result = {"error": str(e)}
    except ValueError as e:
        # A 200 response whose body is not valid UTF-8 JSON.
        result = {"error": f"Context passage service returned an unreadable response: {e}"}

    return result
158
+
159
+
160
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_sentence_transformer():
    """Return the 'all-MiniLM-L6-v2' SentenceTransformer (memoised by st.cache)."""
    model_name = 'all-MiniLM-L6-v2'
    return SentenceTransformer(model_name)
163
+
164
+
165
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_sentence_transformer_encoding(sentences):
    """Encode ``sentences`` with the shared sentence transformer.

    The iterable is copied into a list before encoding; the encoder is asked
    for tensor output (``convert_to_tensor=True``).
    """
    encoder = get_sentence_transformer()
    sentence_batch = list(sentences)
    return encoder.encode(sentence_batch, convert_to_tensor=True)
169
+
170
+
171
def sign_jwt() -> str:
    """Create a signed JWT for authenticating against the backend services.

    The payload carries a single ``"expires"`` entry set to the current time
    plus 6000 seconds, signed with ``JWT_SECRET`` using ``JWT_ALGORITHM``.

    Returns:
        The encoded token as ``str``, ready to be interpolated into a
        ``Bearer`` Authorization header.
    """
    # NOTE(review): "expires" is a custom claim rather than the registered
    # "exp" claim, so expiry is presumably enforced by the receiving service;
    # confirm before renaming.
    payload = {
        "expires": time.time() + 6000
    }
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    # If the installed jwt library returns bytes, interpolating it into an
    # f-string would produce "Bearer b'...'"; normalise to str.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token
177
+
178
+
179
def extract_sentences_from_passages(passages):
    """Flatten ``passages`` into one list of sentences.

    Each passage's ``"text"`` entry is split with ``tokenize.sent_tokenize``;
    sentences are returned in passage order.
    """
    return [
        sentence
        for passage in passages
        for sentence in tokenize.sent_tokenize(passage["text"])
    ]
184
+
185
+
186
def similarity_color_picker(similarity: float):
    """Map a similarity score to an ``[r, g, b]`` list of 0-255 integers.

    The score is clamped to ``[0, 1]`` and converted to a fully saturated,
    full-brightness HSV colour whose hue runs from 0 (score 0) to 0.25
    (score 1).

    Clamping matters because a negative score would otherwise produce a
    negative hue, for which ``colorsys.hsv_to_rgb`` returns a negative channel
    (e.g. ``[255, -189, 0]`` for -0.5) that cannot be rendered as a hex colour.
    """
    clamped = min(max(similarity, 0.0), 1.0)
    hue_step = int(clamped * 75)
    rgb = colorsys.hsv_to_rgb(hue_step / 300., 1.0, 1.0)
    return [round(255 * channel) for channel in rgb]
190
+
191
+
192
def rgb_to_hex(rgb):
    """Format three integer channels as a six-digit lowercase hex string (no ``#``)."""
    channels = tuple(rgb)
    return '%02x%02x%02x' % channels
194
+
195
+
196
def similiarity_to_hex(similarity: float):
    """Return the hex colour string for ``similarity`` (colour picker + hex formatter)."""
    rgb_channels = similarity_color_picker(similarity)
    return rgb_to_hex(rgb_channels)
198
+
199
+
200
@st.cache(allow_output_mutation=True, show_spinner=False)
def _get_cross_encoder():
    """Load the re-ranking cross-encoder once and reuse it on later calls."""
    return CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")


def rerank(question: str, passages: List[dict], include_rank: int = 4) -> List[dict]:
    """Return the ``include_rank`` passages scored highest against ``question``.

    Args:
        question: The user's question.
        passages: Passage dicts; each must have a ``"text"`` entry.
        include_rank: Maximum number of passages to return.

    Returns:
        The selected passage dicts, ordered from highest to lowest
        cross-encoder score. An empty input yields an empty list without
        loading or invoking the model.
    """
    if not passages:
        return []

    # The model is loaded through a cached helper instead of being rebuilt on
    # every call, as get_sentence_transformer() does for the sentence encoder.
    ce = _get_cross_encoder()
    question_passage_combinations = [[question, p["text"]] for p in passages]

    # Compute the similarity scores for these combinations
    similarity_scores = ce.predict(question_passage_combinations)

    # Sort the scores in decreasing order
    sim_ranking_idx = np.flip(np.argsort(similarity_scores))
    return [passages[rank_idx] for rank_idx in sim_ranking_idx[:include_rank]]
210
+
211
+
212
def answer_to_context_similarity(generated_answer, context_passages, topk=3):
    """Match each answer sentence to its most similar context sentences.

    Args:
        generated_answer: The generated answer text.
        context_passages: Passage dicts with a ``"text"`` entry.
        topk: Maximum number of context sentences to report per answer sentence.

    Returns:
        One dict per answer sentence:
        ``{"answer": <sentence>, "context": [{"source": <sentence>, "score": <score>}, ...]}``.
        ``"context"`` holds however many hits the search returned, which can be
        fewer than ``topk`` when there are few context sentences. Returns an
        empty list when either the answer or the context has no sentences.
    """
    context_sentences = extract_sentences_from_passages(context_passages)
    answer_sentences = tokenize.sent_tokenize(generated_answer)
    if not context_sentences or not answer_sentences:
        return []

    context_sentences_e = get_sentence_transformer_encoding(context_sentences)
    answer_sentences_e = get_sentence_transformer_encoding(answer_sentences)
    search_result = util.semantic_search(answer_sentences_e, context_sentences_e, top_k=topk)

    # Iterate over the hits actually returned rather than range(topk), which
    # raised IndexError whenever fewer than topk hits were available.
    return [
        {
            "answer": answer_sentence,
            "context": [
                {"source": context_sentences[hit["corpus_id"]], "score": hit["score"]}
                for hit in hits
            ],
        }
        for answer_sentence, hits in zip(answer_sentences, search_result)
    ]
225
+
226
+
227
def post_process_answer(generated_answer):
    """Drop a trailing truncated sentence from ``generated_answer``.

    The answer is split with ``tokenize.sent_tokenize`` and, separately,
    matched against a sentence-boundary regex. When the tokenizer finds more
    sentences than the regex, the last tokenized sentence is treated as
    truncated and removed. The result is stripped of surrounding whitespace.
    """
    # sentence boundary pattern: capitalised start ... terminator followed by
    # end of text or " <Capital>"
    sentence_boundary = r"([A-Z][a-z].*?[.:!?](?=$| [A-Z]))"
    tokenized = tokenize.sent_tokenize(generated_answer)
    complete_count = len(re.findall(sentence_boundary, generated_answer))

    if len(tokenized) <= complete_count:
        return generated_answer.strip()
    # more tokenized sentences than complete ones: the last one is cut off
    return " ".join(tokenized[:-1]).strip()
237
+
238
+
239
def format_score(value: float, precision=2):
    """Render ``value`` in fixed-point notation with ``precision`` decimals."""
    spec = f".{precision}f"
    return format(value, spec)
241
+
242
+
243
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_answer(question: str):
    """Answer ``question`` using retrieved context passages.

    Returns:
        * ``{}`` when ``question`` is empty;
        * ``{"error": ...}`` when the question has three words or fewer, or
          when either the context service or the LFQA backend reports an
          error (their response is returned as-is);
        * otherwise ``{"context_passages": [...], "answer": "..."}``.

    The backend is chosen by ``st.session_state["api_lfqa_selector"]``. Results
    are memoised by ``st.cache``.
    """
    if not question:
        return {}

    if len(question.split()) <= 3:
        return {"error": f"A longer, more descriptive question will receive a better answer. '{question}' is too short."}

    header = {"Authorization": f"Bearer {sign_jwt()}"}
    retrieved = request_context_passages(question, header)
    if "error" in retrieved:
        return retrieved

    top_passages = rerank(question, retrieved)
    conditioned_context = "<P> " + " <P> ".join([d["text"] for d in top_passages])
    model_input = f'question: {question} context: {conditioned_context}'

    inference_response = invoke_lfqa(st.session_state["api_lfqa_selector"], model_input, header)
    if "error" in inference_response:
        return inference_response

    return {
        "context_passages": top_passages,
        "answer": post_process_answer(inference_response[0]["generated_text"]),
    }
268
+
269
+
270
def _render_answer_html(sentence_similarity):
    """Build the annotated-answer HTML: one span per answer sentence with a score tooltip."""
    import html  # local import: only the HTML builders need escaping

    parts = ["<div class='sentence-wrapper'>"]
    for item in sentence_similarity:
        answer_text = html.escape(item["answer"])
        if not item["context"]:
            # No supporting sentence: show the sentence without a tooltip.
            parts.append(f"<span> {answer_text}</span>")
            continue
        best_hit = item["context"][0]
        score = best_hit["score"]
        support_sentence = html.escape(best_hit["source"])
        parts.append("".join([
            '<span>',
            f' {answer_text}',
            f'<span style="background-color: #{similiarity_to_hex(score)}" class="tooltip">',
            f'{format_score(score, precision=1)}',
            f'<span class="tooltiptext"><b>Wikipedia source</b><br><br> {support_sentence} <br><br>Similarity: {format_score(score)}</span>',
            '</span>',  # closes the tooltip span
            '</span>',  # closes the sentence span
        ]))
    parts.append("</div>")
    return "".join(parts)


def _question_similarities(question_e, texts):
    """Return one cosine similarity per entry of ``texts`` against ``question_e``, as a list."""
    if not texts:
        return []
    texts_e = get_sentence_transformer_encoding(texts)
    # Row 0 of the score matrix is the question against every text; taking it
    # without squeeze() keeps the result a list even for a single text.
    return util.cos_sim(question_e, texts_e)[0].tolist()


def app():
    """Render the Wikipedia Assistant Streamlit page.

    Loads ``style.css``, draws the footer and title, and, when a question is
    set, shows the generated answer annotated with per-sentence similarity to
    the context, an audio rendition of the answer, and a selectable view of
    the context (paragraphs, sentences, or answer similarity).

    The question input is currently commented out and ``question`` is fixed to
    the empty string, so only the header and migration notice are shown.
    """
    import html  # local import: used to escape text interpolated into HTML

    with open('style.css') as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    footer = """
        <div class="footer-custom">
            Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
            LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
            Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a> |
            <a href="https://towardsdatascience.com/long-form-qa-beyond-eli5-an-updated-dataset-and-approach-319cb841aabb" target="_blank">Blog</a>
        </div>
    """
    st.markdown(footer, unsafe_allow_html=True)

    st.title('Wikipedia Assistant')
    st.header('We are migrating to new backend infrastructure. ETA - 15.6.2022')

    # question = st.text_input(
    #     label='Ask Wikipedia an open-ended question below; for example, "Why do airplanes leave contrails in the sky?"')
    question = ""
    spinner = st.empty()
    if question != "":
        spinner.markdown(
            f"""
            <div class="loader-wrapper">
            <div class="loader">
            </div>
            <p>Generating answer for: <b>{html.escape(question)}</b></p>
            </div>
            <label class="loader-note">Answer generation may take up to 20 sec. Please stand by.</label>
            """,
            unsafe_allow_html=True,
        )

    question_response = get_answer(question)
    if not question_response:
        return

    # Clear the loader on every outcome, including errors.
    spinner.markdown("")
    if "error" in question_response:
        st.warning(question_response["error"])
        return

    generated_answer = question_response["answer"]
    context_passages = question_response["context_passages"]
    sentence_similarity = answer_to_context_similarity(generated_answer, context_passages, topk=3)
    st.markdown(_render_answer_html(sentence_similarity), unsafe_allow_html=True)

    with st.spinner("Generating audio..."):
        if st.session_state["tts"] == "HuggingFace":
            audio_path = "out.flac"
            audio_bytes = hf_tts(generated_answer)
        else:
            audio_path = "out.mp3"
            audio_bytes = google_tts(generated_answer, st.secrets["private_key_id"],
                                     st.secrets["private_key"], st.secrets["client_email"]).audio_content
        # NOTE(review): the output file name is fixed, so concurrent sessions
        # write to the same path; consider passing the bytes to st.audio.
        with open(audio_path, "wb") as f:
            f.write(audio_bytes)
        st.audio(audio_path)

    st.markdown("""<hr></hr>""", unsafe_allow_html=True)

    model = get_sentence_transformer()

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Context")
    with col2:
        selection = st.selectbox(
            label="",
            options=('Paragraphs', 'Sentences', 'Answer Similarity'),
            help="Context represents Wikipedia passages used to generate the answer")
    question_e = model.encode(question, convert_to_tensor=True)
    if selection == "Paragraphs":
        # Score each passage's own text; previously sentence-level scores were
        # assigned to passages by position.
        scores = _question_similarities(question_e, [node["text"] for node in context_passages])
        ranked = sorted(zip(scores, context_passages), key=lambda pair: pair[0], reverse=True)
        # Sort on the numeric score and build new dicts so the passages held
        # in get_answer's cache are not mutated.
        st.json([{**node, "answer_similarity": "{0:.2f}".format(score)} for score, node in ranked])
    elif selection == "Sentences":
        sentences = extract_sentences_from_passages(context_passages)
        scores = _question_similarities(question_e, sentences)
        ranked = sorted(zip(scores, sentences), key=lambda pair: pair[0], reverse=True)
        result = [{"text": sentence, "answer_similarity": "{0:.2f}".format(score)}
                  for score, sentence in ranked]
        st.json(json.dumps(result))
    else:
        st.json(sentence_similarity)