VyLala committed on
Commit
82a0c67
·
verified ·
1 Parent(s): 0de9969

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +35 -3
pipeline.py CHANGED
@@ -311,7 +311,17 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
311
  "time_cost":None,
312
  "source":links,
313
  "file_chunk":"",
314
- "file_all_output":""}
 
 
 
 
 
 
 
 
 
 
315
  if niche_cases:
316
  for niche in niche_cases:
317
  acc_score[niche] = {}
@@ -327,6 +337,8 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
327
  if pudID:
328
  id = str(pudID)
329
  saveTitle = title
 
 
330
  else:
331
  try:
332
  author_name = meta_expand["authors"].split(',')[0] # Use last name only
@@ -396,6 +408,10 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
396
  if stand_country.lower() != "not found":
397
  acc_score["country"][stand_country.lower()] = ["ncbi"]
398
  else: acc_score["country"][country.lower()] = ["ncbi"]
 
 
 
 
399
  if sample_type.lower() != "unknown":
400
  acc_score["sample_type"][sample_type.lower()] = ["ncbi"]
401
  # second way: LLM model
@@ -841,7 +857,7 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
841
  print("this is text for the last resort model")
842
  print(text)
843
 
844
- predicted_outputs, method_used, total_query_cost, more_links = await model.query_document_info(
845
  niche_cases=niche_cases,
846
  query_word=primary_word, alternative_query_word=alternative_word,
847
  saveLinkFolder = sample_folder_id,
@@ -851,7 +867,12 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
851
  print("add more links from model.query document")
852
  if more_links:
853
  links += more_links
854
- acc_score["source"] = links
 
 
 
 
 
855
  print("this is llm results: ")
856
  for pred_out in predicted_outputs:
857
  # only for country, we have to standardize
@@ -865,6 +886,9 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
865
  stand_country = standardize_location.smart_country_lookup(country.lower())
866
  if clean_country == "unknown" and stand_country.lower() == "not found":
867
  country = "unknown"
 
 
 
868
  if country.lower() != "unknown":
869
  stand_country = standardize_location.smart_country_lookup(country.lower())
870
  print("this is stand_country: ", stand_country)
@@ -874,6 +898,8 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
874
  acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
875
  else:
876
  acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
 
 
877
  else:
878
  if country.lower() in acc_score["country"]:
879
  if country_explanation:
@@ -882,6 +908,12 @@ async def pipeline_with_gemini(accessions,stop_flag=None, save_df=None, niche_ca
882
  else:
883
  if len(method_used + country_explanation) > 0:
884
  acc_score["country"][country.lower()] = [method_used + country_explanation]
 
 
 
 
 
 
885
  # for sample type
886
  elif pred_out == "modern/ancient/unknown":
887
  sample_type = predicted_outputs[pred_out]["answer"]
 
311
  "time_cost":None,
312
  "source":links,
313
  "file_chunk":"",
314
+ "file_all_output":"",
315
+ "signals":{ # default values
316
+ "has_geo_loc_name": False,
317
+ "has_pubmed": False,
318
+ "accession_found_in_text": False,
319
+ "predicted_country": None,
320
+ "genbank_country": None,
321
+ "num_publications": 0,
322
+ "missing_key_fields": False,
323
+ "known_failure_pattern": False,},
324
+ }
325
  if niche_cases:
326
  for niche in niche_cases:
327
  acc_score[niche] = {}
 
337
  if pudID:
338
  id = str(pudID)
339
  saveTitle = title
340
+ # save in signals that pubmed exists
341
+ acc_score["signals"]["has_pubmed"] = True
342
  else:
343
  try:
344
  author_name = meta_expand["authors"].split(',')[0] # Use last name only
 
408
  if stand_country.lower() != "not found":
409
  acc_score["country"][stand_country.lower()] = ["ncbi"]
410
  else: acc_score["country"][country.lower()] = ["ncbi"]
411
+ # record in signals that NCBI already provides a country
412
+ acc_score["signals"]["has_geo_loc_name"] = True
413
+ acc_score["signals"]["genbank_country"] = list(acc_score["country"].keys())[0]
414
+ acc_score["signals"]["num_publications"] += 1 # ncbi also counts as 1 source
415
  if sample_type.lower() != "unknown":
416
  acc_score["sample_type"][sample_type.lower()] = ["ncbi"]
417
  # second way: LLM model
 
857
  print("this is text for the last resort model")
858
  print(text)
859
 
860
+ predicted_outputs, method_used, total_query_cost, more_links, accession_found_in_text = await model.query_document_info(
861
  niche_cases=niche_cases,
862
  query_word=primary_word, alternative_query_word=alternative_word,
863
  saveLinkFolder = sample_folder_id,
 
867
  print("add more links from model.query document")
868
  if more_links:
869
  links += more_links
870
+ acc_score["source"] = links
871
+ # add into the number of publications
872
+ acc_score["signals"]["num_publications"] += len(acc_score["source"])
873
+ # add if accession_found_in_text or not
874
+ acc_score["signals"]["accession_found_in_text"] = accession_found_in_text
875
+
876
  print("this is llm results: ")
877
  for pred_out in predicted_outputs:
878
  # only for country, we have to standardize
 
886
  stand_country = standardize_location.smart_country_lookup(country.lower())
887
  if clean_country == "unknown" and stand_country.lower() == "not found":
888
  country = "unknown"
889
+ # predicted country is unknown
890
+ acc_score["signals"]["predicted_country"] = "unknown"
891
+ acc_score["signals"]["known_failure_pattern"] = True
892
  if country.lower() != "unknown":
893
  stand_country = standardize_location.smart_country_lookup(country.lower())
894
  print("this is stand_country: ", stand_country)
 
898
  acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
899
  else:
900
  acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
901
+ # predicted country is known (not "unknown")
902
+ acc_score["signals"]["predicted_country"] = stand_country.lower()
903
  else:
904
  if country.lower() in acc_score["country"]:
905
  if country_explanation:
 
908
  else:
909
  if len(method_used + country_explanation) > 0:
910
  acc_score["country"][country.lower()] = [method_used + country_explanation]
911
+ # predicted country is known (not "unknown")
912
+ acc_score["signals"]["predicted_country"] = country.lower()
913
+ else:
914
+ # predicted country is unknown
915
+ acc_score["signals"]["predicted_country"] = "unknown"
916
+ acc_score["signals"]["known_failure_pattern"] = True
917
  # for sample type
918
  elif pred_out == "modern/ancient/unknown":
919
  sample_type = predicted_outputs[pred_out]["answer"]