Spaces:
Running
Running
Update pipeline.py
Browse files- pipeline.py +35 -3
pipeline.py
CHANGED
|
@@ -311,7 +311,17 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
|
|
| 311 |
"time_cost":None,
|
| 312 |
"source":links,
|
| 313 |
"file_chunk":"",
|
| 314 |
-
"file_all_output":""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
if niche_cases:
|
| 316 |
for niche in niche_cases:
|
| 317 |
acc_score[niche] = {}
|
|
@@ -327,6 +337,8 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
|
|
| 327 |
if pudID:
|
| 328 |
id = str(pudID)
|
| 329 |
saveTitle = title
|
|
|
|
|
|
|
| 330 |
else:
|
| 331 |
try:
|
| 332 |
author_name = meta_expand["authors"].split(',')[0] # Use last name only
|
|
@@ -396,6 +408,10 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
|
|
| 396 |
if stand_country.lower() != "not found":
|
| 397 |
acc_score["country"][stand_country.lower()] = ["ncbi"]
|
| 398 |
else: acc_score["country"][country.lower()] = ["ncbi"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
if sample_type.lower() != "unknown":
|
| 400 |
acc_score["sample_type"][sample_type.lower()] = ["ncbi"]
|
| 401 |
# second way: LLM model
|
|
@@ -841,7 +857,7 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
|
|
| 841 |
print("this is text for the last resort model")
|
| 842 |
print(text)
|
| 843 |
|
| 844 |
-
predicted_outputs, method_used, total_query_cost, more_links = await model.query_document_info(
|
| 845 |
niche_cases=niche_cases,
|
| 846 |
query_word=primary_word, alternative_query_word=alternative_word,
|
| 847 |
saveLinkFolder = sample_folder_id,
|
|
@@ -851,7 +867,12 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
|
|
| 851 |
print("add more links from model.query document")
|
| 852 |
if more_links:
|
| 853 |
links += more_links
|
| 854 |
-
acc_score["source"] = links
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
print("this is llm results: ")
|
| 856 |
for pred_out in predicted_outputs:
|
| 857 |
# only for country, we have to standardize
|
|
@@ -865,6 +886,9 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
|
|
| 865 |
stand_country = standardize_location.smart_country_lookup(country.lower())
|
| 866 |
if clean_country == "unknown" and stand_country.lower() == "not found":
|
| 867 |
country = "unknown"
|
|
|
|
|
|
|
|
|
|
| 868 |
if country.lower() != "unknown":
|
| 869 |
stand_country = standardize_location.smart_country_lookup(country.lower())
|
| 870 |
print("this is stand_country: ", stand_country)
|
|
@@ -874,6 +898,8 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
|
|
| 874 |
acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
|
| 875 |
else:
|
| 876 |
acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
|
|
|
|
|
|
|
| 877 |
else:
|
| 878 |
if country.lower() in acc_score["country"]:
|
| 879 |
if country_explanation:
|
|
@@ -882,6 +908,12 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
|
|
| 882 |
else:
|
| 883 |
if len(method_used + country_explanation) > 0:
|
| 884 |
acc_score["country"][country.lower()] = [method_used + country_explanation]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 885 |
# for sample type
|
| 886 |
elif pred_out == "modern/ancient/unknown":
|
| 887 |
sample_type = predicted_outputs[pred_out]["answer"]
|
|
|
|
| 311 |
"time_cost":None,
|
| 312 |
"source":links,
|
| 313 |
"file_chunk":"",
|
| 314 |
+
"file_all_output":"",
|
| 315 |
+
"signals":{ # default values
|
| 316 |
+
"has_geo_loc_name": False,
|
| 317 |
+
"has_pubmed": False,
|
| 318 |
+
"accession_found_in_text": False,
|
| 319 |
+
"predicted_country": None,
|
| 320 |
+
"genbank_country": None,
|
| 321 |
+
"num_publications": 0,
|
| 322 |
+
"missing_key_fields": False,
|
| 323 |
+
"known_failure_pattern": False,},
|
| 324 |
+
}
|
| 325 |
if niche_cases:
|
| 326 |
for niche in niche_cases:
|
| 327 |
acc_score[niche] = {}
|
|
|
|
| 337 |
if pudID:
|
| 338 |
id = str(pudID)
|
| 339 |
saveTitle = title
|
| 340 |
+
# record in signals that a PubMed ID exists for this accession
|
| 341 |
+
acc_score["signals"]["has_pubmed"] = True
|
| 342 |
else:
|
| 343 |
try:
|
| 344 |
author_name = meta_expand["authors"].split(',')[0] # Use last name only
|
|
|
|
| 408 |
if stand_country.lower() != "not found":
|
| 409 |
acc_score["country"][stand_country.lower()] = ["ncbi"]
|
| 410 |
else: acc_score["country"][country.lower()] = ["ncbi"]
|
| 411 |
+
# record signals showing that NCBI already provides a country for this accession
|
| 412 |
+
acc_score["signals"]["has_geo_loc_name"] = True
|
| 413 |
+
acc_score["signals"]["genbank_country"] = list(acc_score["country"].keys())[0]
|
| 414 |
+
acc_score["signals"]["num_publications"] += 1 # ncbi also counts as 1 source
|
| 415 |
if sample_type.lower() != "unknown":
|
| 416 |
acc_score["sample_type"][sample_type.lower()] = ["ncbi"]
|
| 417 |
# second way: LLM model
|
|
|
|
| 857 |
print("this is text for the last resort model")
|
| 858 |
print(text)
|
| 859 |
|
| 860 |
+
predicted_outputs, method_used, total_query_cost, more_links, accession_found_in_text = await model.query_document_info(
|
| 861 |
niche_cases=niche_cases,
|
| 862 |
query_word=primary_word, alternative_query_word=alternative_word,
|
| 863 |
saveLinkFolder = sample_folder_id,
|
|
|
|
| 867 |
print("add more links from model.query document")
|
| 868 |
if more_links:
|
| 869 |
links += more_links
|
| 870 |
+
acc_score["source"] = links
|
| 871 |
+
# add the number of collected source links to the publication count
|
| 872 |
+
acc_score["signals"]["num_publications"] += len(acc_score["source"])
|
| 873 |
+
# record whether the accession was found in the text
|
| 874 |
+
acc_score["signals"]["accession_found_in_text"] = accession_found_in_text
|
| 875 |
+
|
| 876 |
print("this is llm results: ")
|
| 877 |
for pred_out in predicted_outputs:
|
| 878 |
# only for country, we have to standardize
|
|
|
|
| 886 |
stand_country = standardize_location.smart_country_lookup(country.lower())
|
| 887 |
if clean_country == "unknown" and stand_country.lower() == "not found":
|
| 888 |
country = "unknown"
|
| 889 |
+
# predicted country is unknown
|
| 890 |
+
acc_score["signals"]["predicted_country"] = "unknown"
|
| 891 |
+
acc_score["signals"]["known_failure_pattern"] = True
|
| 892 |
if country.lower() != "unknown":
|
| 893 |
stand_country = standardize_location.smart_country_lookup(country.lower())
|
| 894 |
print("this is stand_country: ", stand_country)
|
|
|
|
| 898 |
acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
|
| 899 |
else:
|
| 900 |
acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
|
| 901 |
+
# predicted country is not unknown
|
| 902 |
+
acc_score["signals"]["predicted_country"] = stand_country.lower()
|
| 903 |
else:
|
| 904 |
if country.lower() in acc_score["country"]:
|
| 905 |
if country_explanation:
|
|
|
|
| 908 |
else:
|
| 909 |
if len(method_used + country_explanation) > 0:
|
| 910 |
acc_score["country"][country.lower()] = [method_used + country_explanation]
|
| 911 |
+
# predicted country is not unknown
|
| 912 |
+
acc_score["signals"]["predicted_country"] = country.lower()
|
| 913 |
+
else:
|
| 914 |
+
# predicted country is unknown
|
| 915 |
+
acc_score["signals"]["predicted_country"] = "unknown"
|
| 916 |
+
acc_score["signals"]["known_failure_pattern"] = True
|
| 917 |
# for sample type
|
| 918 |
elif pred_out == "modern/ancient/unknown":
|
| 919 |
sample_type = predicted_outputs[pred_out]["answer"]
|