Commit 1f1e9bd
Parent(s): 5cc7b84

don't use concat
app.py CHANGED
@@ -145,7 +145,7 @@ def init_models():
         "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
         device=device
     )
-    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     return question_answerer, reranker, stop, device
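The reranker here is a sentence-transformers cross-encoder: it scores each (query, passage) pair jointly rather than comparing separate embeddings, which is what makes it useful as a second-stage ranking signal. A minimal sketch of the usage pattern, with hypothetical query and passages:

from sentence_transformers import CrossEncoder

# Passing device explicitly (as the new line now does) keeps the reranker
# on the same device as the QA pipeline instead of letting the library choose.
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cpu')

query = "What causes asthma?"  # hypothetical inputs
passages = [
    "Asthma is driven by chronic airway inflammation.",
    "The library opens at nine on weekdays.",
]

# predict() returns one relevance score per (query, passage) pair.
scores = reranker.predict([(query, p) for p in passages])
ranked = [p for _, p in sorted(zip(scores, passages), key=lambda x: x[0], reverse=True)]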
@@ -211,6 +211,9 @@ st.markdown("""
 """, unsafe_allow_html=True)
 
 with st.expander("Settings (strictness, context limit, top hits)"):
+    concat_passages = st.radio(
+        "Concatenate passages as one long context?",
+        ('no', 'yes'))
     support_all = st.radio(
         "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
         ('yes', 'no'))
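st.radio returns the selected option as a string and defaults to the first element of the options tuple, so ordering the new choices as ('no', 'yes') makes no-concatenation the default, consistent with the commit title. The pattern in isolation:

import streamlit as st

# The first tuple element ('no') is the default selection.
concat_passages = st.radio("Concatenate passages as one long context?", ('no', 'yes'))
if concat_passages == 'yes':
    st.write("Passages will be joined into a single context.")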
@@ -224,8 +227,8 @@ with st.expander("Settings (strictness, context limit, top hits)"):
     use_reranking = st.radio(
         "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
         ('yes', 'no'))
-    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300,
-    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300,
+    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 10)
+    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 5)
 
 # def paraphrase(text, max_length=128):
 #     input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
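The slider change supplies the fourth positional argument, which st.slider treats as the initial value (label, min, max, default), so the app now starts at 10 reranking hits and 5 answering contexts. The call shape, with labels shortened here for readability:

import streamlit as st

# st.slider(label, min_value, max_value, value)
top_hits_limit = st.slider('Top hits?', 10, 300, 10)
context_lim = st.slider('Context limit?', 10, 300, 5)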
@@ -313,14 +316,24 @@ def run_query(query):
         scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
         hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
         sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
-        context = '\n---'.join(sorted_contexts[:context_limit])
+        contexts = sorted_contexts[:context_limit]
     else:
-        context = '\n---'.join(contexts[:context_limit])
+        contexts = contexts[:context_limit]
+
+    if concat_passages == 'yes':
+        context = '\n---'.join(contexts)
+        model_results = qa_model(question=query, context=context, top_k=10)
+    else:
+        context = ['\n---\n'+ctx for ctx in contexts]
+        model_results = qa_model(question=[query]*len(contexts), context=context)
 
     results = []
-    model_results = qa_model(question=query, context=context, top_k=10)
-    for result in model_results:
-        matched = matched_context(result['start'], result['end'], context)
+
+    for i, result in enumerate(model_results):
+        if concat_passages == 'yes':
+            matched = matched_context(result['start'], result['end'], context)
+        else:
+            matched = matched_context(result['start'], result['end'], context[i])
         support = find_source(result['answer'], orig_docs, matched)
         if not support:
             continue
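The substance of the commit is this branch in run_query: rather than always joining the top passages into one long string, the code now dispatches on concat_passages. The transformers question-answering pipeline accepts either a single question/context pair (returning up to top_k candidate answers) or parallel lists of questions and contexts (returning one result per pair), and in both cases 'start' and 'end' are character offsets into whichever context was passed in, which is why the loop must choose between context and context[i] before calling matched_context. Incidentally, the untouched line sorted(hits.items(), key=lambda x: x[0], reverse=True) orders by passage text; sorting by reranker score would use key=lambda x: x[1]. A sketch of the two call shapes, with hypothetical inputs:

from transformers import pipeline

qa_model = pipeline("question-answering",
                    model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B')

query = "What causes asthma?"                      # hypothetical
contexts = ["passage one ...", "passage two ..."]  # already reranked and truncated

# Concatenated mode: one long context, up to 10 candidates whose
# start/end offsets index into the joined string.
joined = '\n---'.join(contexts)
for r in qa_model(question=query, context=joined, top_k=10):
    span = joined[r['start']:r['end']]  # equals r['answer']

# Per-passage mode: one question per context, one result per pair;
# offsets then index into that passage alone.
for i, r in enumerate(qa_model(question=[query] * len(contexts), context=contexts)):
    span = contexts[i][r['start']:r['end']]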
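matched_context is defined elsewhere in app.py and does not appear in this diff; presumably it maps an answer's character offsets back to the passage that contains them. A hypothetical reconstruction, purely to illustrate what the context vs. context[i] distinction buys:

# Hypothetical sketch; the real matched_context in app.py may differ.
def matched_context(start, end, context, sep='\n---'):
    # Walk the separator-delimited passages and return the one that
    # fully contains the [start, end) answer span.
    offset = 0
    for passage in context.split(sep):
        if offset <= start and end <= offset + len(passage):
            return passage
        offset += len(passage) + len(sep)
    return None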