Spaces:
Runtime error
Runtime error
added beta deep research mode
Browse files- app_gradio.py +112 -12
app_gradio.py
CHANGED
|
@@ -480,6 +480,100 @@ def make_embedding_plot(papers_df, top_k, consensus_answer, arxiv_corpus=arxiv_c
|
|
| 480 |
plt.axis('off')
|
| 481 |
return fig
|
| 482 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type, ec=ec, progress=gr.Progress()):
|
| 484 |
|
| 485 |
yield None, None, None, None, None
|
|
@@ -507,21 +601,26 @@ def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type,
|
|
| 507 |
ec.hyde = True
|
| 508 |
ec.rerank = True
|
| 509 |
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
|
| 519 |
-
progress(0.6, desc="Generating consensus")
|
| 520 |
consensus_answer = evaluate_overall_consensus(query, [formatted_df['abstract'][i+1] for i in range(len(formatted_df))])
|
| 521 |
consensus = '## Consensus \n'+consensus_answer.consensus + '\n\n'+consensus_answer.explanation + '\n\n > Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score
|
| 522 |
yield formatted_df, rag_answer['answer'], consensus, None, None
|
| 523 |
|
| 524 |
-
progress(0.8, desc="Analyzing question type")
|
| 525 |
question_type_gen = guess_question_type(query)
|
| 526 |
if '<categorization>' in question_type_gen:
|
| 527 |
question_type_gen = question_type_gen.split('<categorization>')[1]
|
|
@@ -531,7 +630,7 @@ def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type,
|
|
| 531 |
qn_type = question_type_gen
|
| 532 |
yield formatted_df, rag_answer['answer'], consensus, qn_type, None
|
| 533 |
|
| 534 |
-
progress(1.0, desc="Visualizing embeddings")
|
| 535 |
fig = make_embedding_plot(formatted_df, top_k, consensus_answer)
|
| 536 |
|
| 537 |
yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
|
|
@@ -551,6 +650,7 @@ def create_interface():
|
|
| 551 |
with gr.Tab("pathfinder"):
|
| 552 |
with gr.Accordion("What is Pathfinder? / How do I use it?", open=False):
|
| 553 |
gr.Markdown(pathfinder_text)
|
|
|
|
| 554 |
|
| 555 |
with gr.Row():
|
| 556 |
query = gr.Textbox(label="Ask me anything")
|
|
@@ -559,7 +659,7 @@ def create_interface():
|
|
| 559 |
top_k = gr.Slider(1, 30, step=1, value=10, label="top-k", info="Number of papers to retrieve")
|
| 560 |
keywords = gr.Textbox(label="Optional Keywords (comma-separated)",value="")
|
| 561 |
toggles = gr.CheckboxGroup(["Keywords", "Time", "Citations"], label="Weight by", info="weighting retrieved papers",value=['Keywords'])
|
| 562 |
-
prompt_type = gr.Radio(choices=["Single-paper", "Multi-paper", "Bibliometric", "Broad but nuanced"], label="Prompt Specialization", value='Multi-paper')
|
| 563 |
rag_type = gr.Radio(choices=["Semantic Search", "Semantic + HyDE", "Semantic + CoHERE", "Semantic + HyDE + CoHERE"], label="RAG Method",value='Semantic + HyDE + CoHERE')
|
| 564 |
with gr.Column(scale=2, min_width=300):
|
| 565 |
img1 = gr.Image("local_files/pathfinder_logo.png")
|
|
|
|
| 480 |
plt.axis('off')
|
| 481 |
return fig
|
| 482 |
|
| 483 |
+
|
| 484 |
+
def getsmallans(query, df):
    """Answer *query* from the retrieved papers in *df* with one small LLM call.

    Builds a system prompt from ``dr_smallans_prompt`` plus each paper's
    "FirstAuthor et al. YYYY" citation, title and abstract, queries
    gpt-4o-mini, then post-processes the answer so that plain
    "Author et al. YYYY" (and "Author et al. (YYYY)") citations become
    markdown links taken from the paper's 'ADS Link' column.

    Parameters
    ----------
    query : str
        The (atomic) question to answer.
    df : pandas.DataFrame
        Retrieved papers; assumed to have 'authors', 'date', 'title',
        'abstract' and 'ADS Link' columns, where 'ADS Link' is a
        pre-formatted '[label](url)' markdown link — TODO confirm against
        ec.return_formatted_df.

    Returns
    -------
    tuple of str
        (smallans, smallauth, linkstr): the linked answer, the
        space-separated citation string, and the ' \n\n'-separated
        markdown link list.
    """
    content_parts = [dr_smallans_prompt]
    auth_parts = []
    link_parts = []
    for _, row in df.iterrows():
        # Compute the "FirstAuthor et al. YYYY" tag once per paper instead
        # of three times (it needs an index, a split, and a date lookup).
        cite = f"{row['authors'][0].split(',')[0]} et al. {row['date'].year}"
        content_parts.append(f"Paper ({cite}): {row['title']}\n{row['abstract']}\n\n")
        auth_parts.append(f"({cite}) ")
        # Reuse only the '(url)' tail of the pre-formatted ADS markdown link.
        link_parts.append(f"[{cite}](" + row['ADS Link'].split('](')[1] + ' \n\n')

    # join() instead of repeated '+=' — linear rather than quadratic.
    allcontent = ''.join(content_parts)
    smallauth = ''.join(auth_parts)
    linkstr = ''.join(link_parts)

    gen_client = openai_llm(temperature=0, model_name='gpt-4o-mini',
                            openai_api_key=openai_key)
    messages = [("system", allcontent), ("human", query)]
    smallans = gen_client.invoke(messages).content

    # Map "Author et al. YYYY" -> "[Author et al. YYYY](url) \n\n" for
    # post-hoc linking; the final split element is always '' and is dropped.
    linkdict = {}
    for entry in linkstr.split(' \n\n')[:-1]:
        linkdict[entry.split('](')[0][1:]] = entry

    # str.replace on str arguments cannot raise, so the original bare
    # 'except:' could only mask genuine bugs — removed.
    for key, link in linkdict.items():
        smallans = smallans.replace(key, link)
        # Also link the "Author et al. (YYYY)" spelling of the same citation.
        key_paren = key[:-4] + '(' + key[-4:] + ')'
        smallans = smallans.replace(key_paren, link)

    return smallans, smallauth, linkstr
|
| 518 |
+
|
| 519 |
+
def compileinfo(query, atom_qns, atom_qn_ans, atom_qn_strs):
    """Compile per-sub-question answers into one summary via gpt-4o-mini.

    Concatenates ``dr_compileinfo_prompt`` with each (question, answer)
    pair as the system prompt, asks the model to synthesize an overall
    answer to *query*, and collects the per-question link strings.

    Parameters
    ----------
    query : str
        The original user question.
    atom_qns, atom_qn_ans, atom_qn_strs : list of str
        Parallel lists (built together in deep_research) of atomic
        questions, their answers, and their markdown link strings.

    Returns
    -------
    tuple of str
        (smallans, links): the compiled summary and the concatenated
        link strings.
    """
    # zip over the parallel lists instead of range(len(...)) indexing —
    # no IndexError if the lists ever drift in length, and linear joins
    # instead of quadratic '+=' concatenation.
    tmp = dr_compileinfo_prompt + ''.join(
        qn + '\n\n' + ans + '\n\n'
        for qn, ans, _ in zip(atom_qns, atom_qn_ans, atom_qn_strs)
    )
    links = ''.join(
        s + '\n\n'
        for _, _, s in zip(atom_qns, atom_qn_ans, atom_qn_strs)
    )

    gen_client = openai_llm(temperature=0, model_name='gpt-4o-mini',
                            openai_api_key=openai_key)
    messages = [("system", tmp), ("human", query)]
    smallans = gen_client.invoke(messages).content
    return smallans, links
|
| 532 |
+
|
| 533 |
+
def deep_research(question, top_k, ec):
    """Run the beta "deep research" pipeline for *question*.

    Steps: (1) have gpt-4o-mini restate the research scope of the
    question (``prompt_qdec2``); (2) extract a concise list of atomic
    sub-questions from that restatement; (3) for each sub-question,
    retrieve *top_k* papers via *ec* and answer it with getsmallans;
    (4) compile everything into a linked summary with compileinfo.

    Parameters
    ----------
    question : str
        The user's research question.
    top_k : int
        Number of papers to retrieve per atomic sub-question.
    ec : object
        Retrieval engine exposing retrieve() and return_formatted_df().

    Returns
    -------
    tuple
        (full_df, rag_answer): the concatenation of all per-question
        DataFrames (empty DataFrame if no sub-questions were found) and
        a dict {'answer': <full markdown answer>}.
    """
    full_answer = '## ' + question

    gen_client = openai_llm(temperature=0, model_name='gpt-4o-mini',
                            openai_api_key=openai_key)

    # Step 1: restate the research scope of the question.
    messages = [("system", prompt_qdec2), ("human", question)]
    rscope_text = gen_client.invoke(messages).content
    full_answer = full_answer + ' \n' + rscope_text

    # Step 2: pull a concise list of atomic sub-questions out of the scope.
    rscope_messages = [("system", """In the given text, what are the main atomic questions being asked? Please answer as a concise list."""), ("human", rscope_text)]
    rscope_qns = gen_client.invoke(rscope_messages).content

    # One sub-question per non-empty line of the model's list.
    atom_qns = [line for line in rscope_qns.split('\n') if line != '']

    atom_qn_dfs = []
    atom_qn_ans = []
    atom_qn_strs = []
    for qn in atom_qns:
        # Step 3: retrieve + format papers, then answer this sub-question.
        rs, small_df = ec.retrieve(qn, top_k=top_k, return_scores=True)
        formatted_df = ec.return_formatted_df(rs, small_df)
        atom_qn_dfs.append(formatted_df)
        smallans, smallauth, linkstr = getsmallans(qn, formatted_df)

        atom_qn_ans.append(smallans)
        atom_qn_strs.append(linkstr)
        full_answer = full_answer + ' \n### ' + qn
        full_answer = full_answer + ' \n' + smallans

    # Step 4: compile the per-question answers into one linked summary.
    finalans, finallinks = compileinfo(question, atom_qns, atom_qn_ans, atom_qn_strs)
    full_answer = full_answer + ' \n' + '### Summary:\n' + finalans

    # pd.concat raises ValueError on an empty list — guard the case where
    # the model produced no atomic sub-questions.
    full_df = pd.concat(atom_qn_dfs) if atom_qn_dfs else pd.DataFrame()

    rag_answer = {'answer': full_answer}

    return full_df, rag_answer
|
| 576 |
+
|
| 577 |
def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type, ec=ec, progress=gr.Progress()):
|
| 578 |
|
| 579 |
yield None, None, None, None, None
|
|
|
|
| 601 |
ec.hyde = True
|
| 602 |
ec.rerank = True
|
| 603 |
|
| 604 |
+
if prompt_type == "Deep Research (BETA)":
|
| 605 |
+
formatted_df, rag_answer = deep_research(query, top_k = top_k, ec=ec)
|
| 606 |
+
yield formatted_df, rag_answer['answer'], None, None, None
|
| 607 |
+
|
| 608 |
+
else:
|
| 609 |
+
# progress(0.2, desc=search_text_list[np.random.choice(len(search_text_list))])
|
| 610 |
+
rs, small_df = ec.retrieve(query, top_k = top_k, return_scores=True)
|
| 611 |
+
formatted_df = ec.return_formatted_df(rs, small_df)
|
| 612 |
+
yield formatted_df, None, None, None, None
|
| 613 |
+
|
| 614 |
+
# progress(0.4, desc=gen_text_list[np.random.choice(len(gen_text_list))])
|
| 615 |
+
rag_answer = run_rag_qa(query, formatted_df, prompt_type)
|
| 616 |
+
yield formatted_df, rag_answer['answer'], None, None, None
|
| 617 |
|
| 618 |
+
# progress(0.6, desc="Generating consensus")
|
| 619 |
consensus_answer = evaluate_overall_consensus(query, [formatted_df['abstract'][i+1] for i in range(len(formatted_df))])
|
| 620 |
consensus = '## Consensus \n'+consensus_answer.consensus + '\n\n'+consensus_answer.explanation + '\n\n > Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score
|
| 621 |
yield formatted_df, rag_answer['answer'], consensus, None, None
|
| 622 |
|
| 623 |
+
# progress(0.8, desc="Analyzing question type")
|
| 624 |
question_type_gen = guess_question_type(query)
|
| 625 |
if '<categorization>' in question_type_gen:
|
| 626 |
question_type_gen = question_type_gen.split('<categorization>')[1]
|
|
|
|
| 630 |
qn_type = question_type_gen
|
| 631 |
yield formatted_df, rag_answer['answer'], consensus, qn_type, None
|
| 632 |
|
| 633 |
+
# progress(1.0, desc="Visualizing embeddings")
|
| 634 |
fig = make_embedding_plot(formatted_df, top_k, consensus_answer)
|
| 635 |
|
| 636 |
yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
|
|
|
|
| 650 |
with gr.Tab("pathfinder"):
|
| 651 |
with gr.Accordion("What is Pathfinder? / How do I use it?", open=False):
|
| 652 |
gr.Markdown(pathfinder_text)
|
| 653 |
+
img2 = gr.Image("local_files/galaxy_worldmap_kiyer-min.png")
|
| 654 |
|
| 655 |
with gr.Row():
|
| 656 |
query = gr.Textbox(label="Ask me anything")
|
|
|
|
| 659 |
top_k = gr.Slider(1, 30, step=1, value=10, label="top-k", info="Number of papers to retrieve")
|
| 660 |
keywords = gr.Textbox(label="Optional Keywords (comma-separated)",value="")
|
| 661 |
toggles = gr.CheckboxGroup(["Keywords", "Time", "Citations"], label="Weight by", info="weighting retrieved papers",value=['Keywords'])
|
| 662 |
+
prompt_type = gr.Radio(choices=["Single-paper", "Multi-paper", "Bibliometric", "Broad but nuanced","Deep Research (BETA)"], label="Prompt Specialization", value='Multi-paper')
|
| 663 |
rag_type = gr.Radio(choices=["Semantic Search", "Semantic + HyDE", "Semantic + CoHERE", "Semantic + HyDE + CoHERE"], label="RAG Method",value='Semantic + HyDE + CoHERE')
|
| 664 |
with gr.Column(scale=2, min_width=300):
|
| 665 |
img1 = gr.Image("local_files/pathfinder_logo.png")
|