VyLala committed on
Commit
06aa1bb
·
verified ·
1 Parent(s): ea8597c

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +276 -446
model.py CHANGED
@@ -17,7 +17,8 @@ import asyncio
17
  import google.generativeai as genai
18
 
19
  #genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
20
- genai.configure(api_key=os.getenv("GOOGLE_API_KEY_BACKUP"))
 
21
 
22
  import nltk
23
  from nltk.corpus import stopwords
@@ -972,7 +973,146 @@ def safe_call_llm(prompt, model="gemini-2.5-flash-lite", max_retries=5):
972
 
973
  raise RuntimeError("❌ Failed after max retries because of repeated rate limits.")
974
 
975
- async def query_document_info(niche_cases, query_word, alternative_query_word, saveLinkFolder, metadata, master_structured_lookup, faiss_index, document_chunks, llm_api_function, chunk=None, all_output=None, model_ai=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976
  """
977
  Queries the document using a hybrid approach:
978
  1. Local structured lookup (fast, cheap, accurate for known patterns).
@@ -980,453 +1120,143 @@ async def query_document_info(niche_cases, query_word, alternative_query_word, s
980
  """
981
  print("inside the model.query_doc_info")
982
  outputs, links, accession_found_in_text = {}, [], False
983
- if model_ai:
984
- if model_ai == "gemini-1.5-flash-latest":
985
- genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
986
- PRICE_PER_1K_INPUT_LLM = 0.000075 # $0.075 per 1M tokens
987
- PRICE_PER_1K_OUTPUT_LLM = 0.0003 # $0.30 per 1M tokens
988
- PRICE_PER_1K_EMBEDDING_INPUT = 0.000025 # $0.025 per 1M tokens
989
- global_llm_model_for_counting_tokens = genai.GenerativeModel("gemini-1.5-flash-latest")#('gemini-1.5-flash-latest')
990
- else:
991
- genai.configure(api_key=os.getenv("GOOGLE_API_KEY_BACKUP"))
992
- # Gemini 2.5 Flash-Lite pricing per 1,000 tokens
993
- PRICE_PER_1K_INPUT_LLM = 0.00010 # $0.10 per 1M input tokens
994
- PRICE_PER_1K_OUTPUT_LLM = 0.00040 # $0.40 per 1M output tokens
995
-
996
- # Embedding-001 pricing per 1,000 input tokens
997
- PRICE_PER_1K_EMBEDDING_INPUT = 0.00015 # $0.15 per 1M input tokens
998
- global_llm_model_for_counting_tokens = genai.GenerativeModel("gemini-2.5-flash-lite")#('gemini-1.5-flash-latest')
999
 
1000
- if metadata:
1001
- extracted_country, extracted_specific_location, extracted_ethnicity, extracted_type = metadata["country"], metadata["specific_location"], metadata["ethnicity"], metadata["sample_type"]
1002
- extracted_col_date, extracted_iso, extracted_title, extracted_features = metadata["collection_date"], metadata["isolate"], metadata["title"], metadata["all_features"]
1003
- else:
1004
- extracted_country, extracted_specific_location, extracted_ethnicity, extracted_type = "unknown", "unknown", "unknown", "unknown"
1005
- extracted_col_date, extracted_iso, extracted_title = "unknown", "unknown", "unknown"
1006
- # --- NEW: Pre-process alternative_query_word to remove '.X' suffix if present ---
1007
- if alternative_query_word:
1008
- alternative_query_word_cleaned = alternative_query_word.split('.')[0]
1009
- else:
1010
- alternative_query_word_cleaned = alternative_query_word
1011
- country_explanation, sample_type_explanation = None, None
1012
-
1013
- # Use the consolidated final_structured_entries for direct lookup
1014
- # final_structured_entries = master_structured_lookup.get('final_structured_entries', {})
1015
- # document_title = master_structured_lookup.get('document_title', 'Unknown Document Title') # Retrieve document title
1016
-
1017
- # Default values for all extracted fields. These will be updated.
1018
- method_used = 'unknown' # Will be updated based on the method that yields a result
1019
- population_code_from_sl = 'unknown' # To pass to RAG prompt if available
1020
- total_query_cost = 0
1021
- # Attempt 1: Try primary query_word (e.g., isolate name) with structured lookup
1022
- # try:
1023
- # print("try attempt 1 in model query")
1024
- # structured_info = final_structured_entries.get(query_word.upper())
1025
- # if structured_info:
1026
- # if extracted_country == 'unknown':
1027
- # extracted_country = structured_info['country']
1028
- # if extracted_type == 'unknown':
1029
- # extracted_type = structured_info['type']
1030
-
1031
- # # if extracted_ethnicity == 'unknown':
1032
- # # extracted_ethnicity = structured_info.get('ethnicity', 'unknown') # Get ethnicity from structured lookup
1033
- # # if extracted_specific_location == 'unknown':
1034
- # # extracted_specific_location = structured_info.get('specific_location', 'unknown') # Get specific_location from structured lookup
1035
- # population_code_from_sl = structured_info['population_code']
1036
- # method_used = "structured_lookup_direct"
1037
- # print(f"'{query_word}' found in structured lookup (direct match).")
1038
- # except:
1039
- # print("pass attempt 1 in model query")
1040
- # pass
1041
- # # Attempt 2: Try primary query_word with heuristic range lookup if direct fails (only if not already resolved)
1042
- # try:
1043
- # print("try attempt 2 in model query")
1044
- # if method_used == 'unknown':
1045
- # query_prefix, query_num_str = _parse_individual_code_parts(query_word)
1046
- # if query_prefix is not None and query_num_str is not None:
1047
- # try: query_num = int(query_num_str)
1048
- # except ValueError: query_num = None
1049
- # if query_num is not None:
1050
- # query_prefix_upper = query_prefix.upper()
1051
- # contiguous_ranges = master_structured_lookup.get('contiguous_ranges', defaultdict(list))
1052
- # pop_code_to_country = master_structured_lookup.get('pop_code_to_country', {})
1053
- # pop_code_to_ethnicity = master_structured_lookup.get('pop_code_to_ethnicity', {})
1054
- # pop_code_to_specific_loc = master_structured_lookup.get('pop_code_to_specific_loc', {})
1055
 
1056
- # if query_prefix_upper in contiguous_ranges:
1057
- # for start_num, end_num, pop_code_for_range in contiguous_ranges[query_prefix_upper]:
1058
- # if start_num <= query_num <= end_num:
1059
- # country_from_heuristic = pop_code_to_country.get(pop_code_for_range, 'unknown')
1060
- # if country_from_heuristic != 'unknown':
1061
- # if extracted_country == 'unknown':
1062
- # extracted_country = country_from_heuristic
1063
- # if extracted_type == 'unknown':
1064
- # extracted_type = 'modern'
1065
- # # if extracted_ethnicity == 'unknown':
1066
- # # extracted_ethnicity = pop_code_to_ethnicity.get(pop_code_for_range, 'unknown')
1067
- # # if extracted_specific_location == 'unknown':
1068
- # # extracted_specific_location = pop_code_to_specific_loc.get(pop_code_for_range, 'unknown')
1069
- # population_code_from_sl = pop_code_for_range
1070
- # method_used = "structured_lookup_heuristic_range_match"
1071
- # print(f"'{query_word}' not direct. Heuristic: Falls within range {query_prefix_upper}{start_num}-{query_prefix_upper}{end_num}.")
1072
- # break
1073
- # else:
1074
- # print(f"'{query_word}' heuristic match found, but country unknown. Will fall to RAG below.")
1075
- # except:
1076
- # print("pass attempt 2 in model query")
1077
- # pass
1078
- # # Attempt 3: If primary query_word failed all structured lookups, try alternative_query_word (cleaned)
1079
- # try:
1080
- # print("try attempt 3 in model query")
1081
- # if method_used == 'unknown' and alternative_query_word_cleaned and alternative_query_word_cleaned != query_word:
1082
- # print(f"'{query_word}' not found in structured (or heuristic). Trying alternative '{alternative_query_word_cleaned}'.")
1083
 
1084
- # # Try direct lookup for alternative word
1085
- # structured_info_alt = final_structured_entries.get(alternative_query_word_cleaned.upper())
1086
- # if structured_info_alt:
1087
- # if extracted_country == 'unknown':
1088
- # extracted_country = structured_info_alt['country']
1089
- # if extracted_type == 'unknown':
1090
- # extracted_type = structured_info_alt['type']
1091
- # # if extracted_ethnicity == 'unknown':
1092
- # # extracted_ethnicity = structured_info_alt.get('ethnicity', 'unknown')
1093
- # # if extracted_specific_location == 'unknown':
1094
- # # extracted_specific_location = structured_info_alt.get('specific_location', 'unknown')
1095
- # population_code_from_sl = structured_info_alt['population_code']
1096
- # method_used = "structured_lookup_alt_direct"
1097
- # print(f"Alternative '{alternative_query_word_cleaned}' found in structured lookup (direct match).")
1098
- # else:
1099
- # # Try heuristic lookup for alternative word
1100
- # alt_prefix, alt_num_str = _parse_individual_code_parts(alternative_query_word_cleaned)
1101
- # if alt_prefix is not None and alt_num_str is not None:
1102
- # try: alt_num = int(alt_num_str)
1103
- # except ValueError: alt_num = None
1104
- # if alt_num is not None:
1105
- # alt_prefix_upper = alt_prefix.upper()
1106
- # contiguous_ranges = master_structured_lookup.get('contiguous_ranges', defaultdict(list))
1107
- # pop_code_to_country = master_structured_lookup.get('pop_code_to_country', {})
1108
- # pop_code_to_ethnicity = master_structured_lookup.get('pop_code_to_ethnicity', {})
1109
- # pop_code_to_specific_loc = master_structured_lookup.get('pop_code_to_specific_loc', {})
1110
- # if alt_prefix_upper in contiguous_ranges:
1111
- # for start_num, end_num, pop_code_for_range in contiguous_ranges[alt_prefix_upper]:
1112
- # if start_num <= alt_num <= end_num:
1113
- # country_from_heuristic_alt = pop_code_to_country.get(pop_code_for_range, 'unknown')
1114
- # if country_from_heuristic_alt != 'unknown':
1115
- # if extracted_country == 'unknown':
1116
- # extracted_country = country_from_heuristic_alt
1117
- # if extracted_type == 'unknown':
1118
- # extracted_type = 'modern'
1119
- # # if extracted_ethnicity == 'unknown':
1120
- # # extracted_ethnicity = pop_code_to_ethnicity.get(pop_code_for_range, 'unknown')
1121
- # # if extracted_specific_location == 'unknown':
1122
- # # extracted_specific_location = pop_code_to_specific_loc.get(pop_code_for_range, 'unknown')
1123
- # population_code_from_sl = pop_code_for_range
1124
- # method_used = "structured_lookup_alt_heuristic_range_match"
1125
- # break
1126
- # else:
1127
- # print(f"Alternative '{alternative_query_word_cleaned}' heuristic match found, but country unknown. Will fall to RAG below.")
1128
- # except:
1129
- # print("pass attempt 3 in model query")
1130
- # pass
1131
- # use the context_for_llm to detect present_ancient before using llm model
1132
- # retrieved_chunks_text = []
1133
- # if document_chunks:
1134
- # for idx in range(len(document_chunks)):
1135
- # retrieved_chunks_text.append(document_chunks[idx])
1136
- # context_for_llm = ""
1137
- # all_context = "\n".join(retrieved_chunks_text) #
1138
- # listOfcontexts = {"chunk": chunk,
1139
- # "all_output": all_output,
1140
- # "document_chunk": all_context}
1141
- # label, context_for_llm = chooseContextLLM(listOfcontexts, query_word)
1142
- # if not context_for_llm:
1143
- # label, context_for_llm = chooseContextLLM(listOfcontexts, alternative_query_word_cleaned)
1144
- # if not context_for_llm:
1145
- # context_for_llm = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + extracted_features
1146
- # if context_for_llm:
1147
- # extracted_type, explain = mtdna_classifier.detect_ancient_flag(context_for_llm)
1148
- # extracted_type = extracted_type.lower()
1149
- # sample_type_explanation = explain
1150
- # 5. Execute RAG if needed (either full RAG or targeted RAG for missing fields)
1151
-
1152
- # Determine if a RAG call is necessary
1153
- # run_rag = (extracted_country == 'unknown' or extracted_type == 'unknown')# or \
1154
- # #extracted_ethnicity == 'unknown' or extracted_specific_location == 'unknown')
1155
- run_rag = True
1156
- if run_rag:
1157
- print("try run rag")
1158
- context_for_llm = ""
1159
- # Determine the phrase for LLM query
1160
- rag_query_phrase = ""
1161
- if query_word.lower() != "unknown":
1162
- rag_query_phrase += f"the mtDNA isolate name '{query_word}'"
1163
- # Accession number (alternative_query_word)
1164
- if (
1165
- alternative_query_word_cleaned
1166
- and alternative_query_word_cleaned != query_word
1167
- and alternative_query_word_cleaned.lower() != "unknown"
1168
- ):
1169
- if rag_query_phrase:
1170
- rag_query_phrase += f" or its accession number '{alternative_query_word_cleaned}'"
1171
- else:
1172
- rag_query_phrase += f"the accession number '{alternative_query_word_cleaned}'"
1173
-
1174
- # Construct a more specific semantic query phrase for embedding if structured info is available
1175
- semantic_query_for_embedding = rag_query_phrase # Default
1176
-
1177
- prompt_instruction_prefix = ""
1178
- output_format_str = ""
1179
-
1180
- # Determine if it's a full RAG or targeted RAG scenario based on what's already extracted
1181
- is_full_rag_scenario = True#(extracted_country == 'unknown')
1182
-
1183
- if is_full_rag_scenario: # Full RAG scenario
1184
- output_format_str = "country_name, modern/ancient/unknown"#, ethnicity, specific_location/unknown"
1185
- explain_list = "country or sample type (modern/ancient)"
1186
- if niche_cases:
1187
- output_format_str += ", "+ ", ".join(niche_cases)# "ethnicity, specific_location/unknown"
1188
- explain_list += " or "+ " or ".join(niche_cases)
1189
- method_used = "rag_llm"
1190
- print(f"Proceeding to FULL RAG for {rag_query_phrase}.")
1191
-
1192
- current_embedding_cost = 0
1193
-
1194
- print("direct to llm")
1195
- listOfcontexts = {"chunk": chunk,
1196
- "all_output": all_output,
1197
- "document_chunk": chunk}
1198
- label, context_for_llm = chooseContextLLM(listOfcontexts, query_word)
1199
- if not context_for_llm:
1200
- label, context_for_llm = chooseContextLLM(listOfcontexts, alternative_query_word_cleaned)
1201
- if not context_for_llm:
1202
- context_for_llm = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + extracted_features
1203
-
1204
- if len(context_for_llm) > 1000*1000:
1205
- context_for_llm = context_for_llm[:900000]
1206
-
1207
- # fix the prompt better:
1208
- # firstly clarify more by saying which type of organism, prioritize homo sapiens
1209
- features = metadata["all_features"]
1210
- organism = "general"
1211
- if features != "unknown":
1212
- if "organism" in features:
1213
- try:
1214
- organism = features.split("organism: ")[1].split("\n")[0]
1215
- except:
1216
- organism = features.replace("\n","; ")
1217
-
1218
- niche_prompt = ""
1219
- if niche_cases:
1220
- fields_list = ", ".join(niche_cases)
1221
- niche_prompt = (
1222
- f"Also, extract {fields_list}. "
1223
- f"If not explicitly stated, infer the most specific related or contextually relevant value. "
1224
- f"If no information is found, write 'unknown'. "
1225
- )
1226
- prompt_for_llm = (
1227
- f"{prompt_instruction_prefix}"
1228
- f"Given the following text snippets, analyze the entity/concept {rag_query_phrase} "
1229
- f"or the mitochondrial DNA sample in {organism} if these identifiers are not explicitly found. "
1230
- f"Identify its **primary associated geographic location**, preferring the most specific available: "
1231
- f"first try to determine the exact country; if no country is explicitly mentioned, then provide "
1232
- f"the next most specific region, continent, island, or other clear geographic area mentioned. "
1233
- f"If no geographic clues at all are present, state 'unknown' for location. "
1234
- f"Also, determine if the genetic sample is from a 'modern' (present-day living individual) "
1235
- f"or 'ancient' (prehistoric/archaeological) source. "
1236
- f"If the text does not specify ancient or archaeological context, assume 'modern'. "
1237
- f"{niche_prompt}"
1238
- f"Provide only {output_format_str}. "
1239
- f"If any information is not explicitly present, use the fallback rules above before defaulting to 'unknown'. "
1240
- f"For each non-'unknown' field in {explain_list}, write one sentence explaining how it was inferred from the text "
1241
- f"(one sentence for each). "
1242
- f"Format your answer so that:\n"
1243
- f"1. The **first line** contains only the {output_format_str} values separated by commas.\n"
1244
- f"2. The **second line onward** contains the explanations based on the order of the non-unknown {output_format_str} answer.\n"
1245
- f"\nText Snippets:\n{context_for_llm}")
1246
 
1247
- print("this is prompt: ", prompt_for_llm)
1248
- # check if accession in text or not
1249
- if alternative_query_word_cleaned.lower() in prompt_for_llm.lower():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1250
  accession_found_in_text = True
1251
-
1252
- if model_ai:
1253
- print("back up to ", model_ai)
1254
- #llm_response_text, model_instance = call_llm_api(prompt_for_llm, model=model_ai)
1255
- llm_response_text, model_instance = safe_call_llm(prompt_for_llm, model=model_ai)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1256
  else:
1257
- print("still 2.5 flash gemini")
1258
- llm_response_text, model_instance = safe_call_llm(prompt_for_llm)
1259
- #llm_response_text, model_instance = call_llm_api(prompt_for_llm)
1260
- print("\n--- DEBUG INFO FOR RAG ---")
1261
- print("Retrieved Context Sent to LLM (first 500 chars):")
1262
- print(context_for_llm[:500] + "..." if len(context_for_llm) > 500 else context_for_llm)
1263
- print("\nRaw LLM Response:")
1264
- print(llm_response_text)
1265
- print("--- END DEBUG INFO ---")
1266
-
1267
- llm_cost = 0
1268
- if model_instance:
1269
- try:
1270
- input_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(prompt_for_llm).total_tokens
1271
- output_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(llm_response_text).total_tokens
1272
- print(f" DEBUG: LLM Input tokens: {input_llm_tokens}")
1273
- print(f" DEBUG: LLM Output tokens: {output_llm_tokens}")
1274
- llm_cost = (input_llm_tokens / 1000) * PRICE_PER_1K_INPUT_LLM + \
1275
- (output_llm_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
1276
- print(f" DEBUG: Estimated LLM cost: ${llm_cost:.6f}")
1277
- except Exception as e:
1278
- print(f" DEBUG: Error counting LLM tokens: {e}")
1279
- llm_cost = 0
1280
-
1281
- total_query_cost += current_embedding_cost + llm_cost
1282
- print(f" DEBUG: Total estimated cost for this RAG query: ${total_query_cost:.6f}")
1283
-
1284
- metadata_list = parse_multi_sample_llm_output(llm_response_text, output_format_str)
1285
-
1286
- print(metadata_list)
1287
- again_output_format, general_knowledge_prompt = "", ""
1288
- # if at least 1 answer is unknown, then do smart queries to get more sources besides doi
1289
- unknown_count = sum(1 for v in metadata_list.values() if v.get("answer").lower() == "unknown")
1290
- if unknown_count >= 1:
1291
- print("at least 1 unknown outputs")
1292
- out_links = {}
1293
- iso, acc = query_word, alternative_query_word
1294
- meta_expand = smart_fallback.fetch_ncbi(acc)
1295
- tem_links = smart_fallback.smart_google_search(acc, meta_expand)
1296
- tem_links = pipeline.unique_preserve_order(tem_links)
1297
- print("this is tem links with acc: ", tem_links)
1298
- # filter the quality link
1299
- print("start the smart filter link")
1300
- #success_process, output_process = run_with_timeout(smart_fallback.filter_links_by_metadata,args=(tem_links,saveLinkFolder),kwargs={"accession":acc},timeout=90)
1301
- output_process = await smart_fallback.async_filter_links_by_metadata(
1302
- tem_links, saveLinkFolder, accession=acc
1303
- )
1304
-
1305
- if output_process:
1306
- out_links.update(output_process)
1307
- print("yeah we have out_link and len: ", len(out_links))
1308
- print("yes succeed for smart filter link")
1309
- links += list(out_links.keys())
1310
- print("link keys: ", links)
1311
- if links:
1312
- tasks = [
1313
- pipeline.process_link_chunk_allOutput(link, iso, acc, saveLinkFolder, out_links, all_output, chunk)
1314
- for link in links
1315
- ]
1316
- print(f"Number of tasks to gather: {len(tasks)}")
1317
- try:
1318
- #results = await asyncio.gather(*tasks)
1319
- results = await asyncio.gather(*tasks, return_exceptions=True)
1320
- print(f"Results: {results}")
1321
- for result in results:
1322
- if isinstance(result, Exception):
1323
- print(f"Error in task: {result}")
1324
- else:
1325
- print(f"Task completed successfully")
1326
-
1327
- print("Finished gathering results")
1328
- except Exception as e:
1329
- print(f"Error in gathering: {e}")
1330
-
1331
- #results = await asyncio.gather(*tasks)
1332
- # combine results
1333
- print("get results for context_for_llm")
1334
- for context, new_all_output, new_chunk in results:
1335
- print("inside new results")
1336
- context_for_llm += new_all_output
1337
- context_for_llm += new_chunk
1338
- print("len of context after merge all: ", len(context_for_llm))
1339
-
1340
- if len(context_for_llm) > 750000:
1341
- context_for_llm = data_preprocess.normalize_for_overlap(context_for_llm)
1342
- if len(context_for_llm) > 750000:
1343
- # use build context for llm function to reduce token
1344
- texts_reduce = []
1345
- out_links_reduce = {}
1346
- reduce_context_for_llm = ""
1347
- if links:
1348
- for link in links:
1349
- all_output_reduce, chunk_reduce, context_reduce = "", "",""
1350
- context_reduce, all_output_reduce, chunk_reduce = await pipeline.process_link_chunk_allOutput(link,
1351
- iso, acc, saveLinkFolder, out_links_reduce,
1352
- all_output_reduce, chunk_reduce)
1353
- texts_reduce.append(all_output_reduce)
1354
- out_links_reduce[link] = {"all_output": all_output_reduce}
1355
- input_prompt = ["country_name", "modern/ancient/unknown"]
1356
- if niche_cases: input_prompt += niche_cases
1357
- reduce_context_for_llm = data_preprocess.build_context_for_llm(texts_reduce, acc, input_prompt)
1358
- if reduce_context_for_llm:
1359
- print("reduce context for llm")
1360
- context_for_llm = reduce_context_for_llm
1361
- else:
1362
- print("no reduce context for llm despite>1M")
1363
- context_for_llm = context_for_llm[:250000]
1364
-
1365
- for key in metadata_list:
1366
- answer = metadata_list[key]["answer"]
1367
- if answer.lower() in " ".join(["unknown", "unspecified","could not get response from llm api.", "undefined"]):
1368
- print("have to do again")
1369
- again_output_format = key
1370
- print("output format:", again_output_format)
1371
- general_knowledge_prompt = (
1372
- f"{prompt_instruction_prefix}"
1373
- f"Given the following text snippets, analyze the entity/concept {rag_query_phrase} "
1374
- f"or the mitochondrial DNA sample in {organism} if these identifiers are not explicitly found. "
1375
- f"Identify and extract {again_output_format}"
1376
- f"If not explicitly stated, infer the most specific related or contextually relevant value. "
1377
- f"If no information is found, write 'unknown'. "
1378
- f"Provide only {again_output_format}. "
1379
- f"For non-'unknown' field in {again_output_format}, write one sentence explaining how it was inferred from the text "
1380
- f"Format your answer so that:\n"
1381
- f"1. The **first line** contains only the {again_output_format} answer.\n"
1382
- f"2. The **second line onward** contains the explanations based on the non-unknown {again_output_format} answer.\n"
1383
- f"\nText Snippets:\n{context_for_llm}")
1384
- print("len of prompt:", len(general_knowledge_prompt))
1385
- if alternative_query_word_cleaned.lower() in general_knowledge_prompt.lower():
1386
- accession_found_in_text = True
1387
- if general_knowledge_prompt:
1388
- if model_ai:
1389
- print("back up to ", model_ai)
1390
- llm_response_text, model_instance = safe_call_llm(general_knowledge_prompt, model=model_ai)
1391
- #llm_response_text, model_instance = call_llm_api(general_knowledge_prompt, model=model_ai)
1392
- else:
1393
- print("still 2.5 flash gemini")
1394
- llm_response_text, model_instance = safe_call_llm(general_knowledge_prompt)
1395
- #llm_response_text, model_instance = call_llm_api(general_knowledge_prompt)
1396
- print("\n--- DEBUG INFO FOR RAG ---")
1397
- print("Retrieved Context Sent to LLM (first 500 chars):")
1398
- print(context_for_llm[:500] + "..." if len(context_for_llm) > 500 else context_for_llm)
1399
- print("\nRaw LLM Response:")
1400
- print(llm_response_text)
1401
- print("--- END DEBUG INFO ---")
1402
-
1403
- llm_cost = 0
1404
- if model_instance:
1405
- try:
1406
- input_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(prompt_for_llm).total_tokens
1407
- output_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(llm_response_text).total_tokens
1408
- print(f" DEBUG: LLM Input tokens: {input_llm_tokens}")
1409
- print(f" DEBUG: LLM Output tokens: {output_llm_tokens}")
1410
- llm_cost = (input_llm_tokens / 1000) * PRICE_PER_1K_INPUT_LLM + \
1411
- (output_llm_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
1412
- print(f" DEBUG: Estimated LLM cost: ${llm_cost:.6f}")
1413
- except Exception as e:
1414
- print(f" DEBUG: Error counting LLM tokens: {e}")
1415
- llm_cost = 0
1416
-
1417
- total_query_cost += current_embedding_cost + llm_cost
1418
- print("total query cost in again: ", total_query_cost)
1419
- metadata_list_one_case = parse_multi_sample_llm_output(llm_response_text, again_output_format)
1420
- print("metadata list after running again unknown output: ", metadata_list)
1421
- for key in metadata_list_one_case:
1422
- print("keys of outputs: ", outputs.keys())
1423
- if key not in list(outputs.keys()):
1424
- print("this is key and about to be added into outputs: ", key)
1425
- outputs[key] = metadata_list_one_case[key]
1426
- else:
1427
- outputs[key] = metadata_list[key]
1428
-
1429
- print("all done and method used: ", outputs, method_used)
1430
- print("total cost: ", total_query_cost)
1431
- return outputs, method_used, total_query_cost, links, accession_found_in_text
1432
-
 
17
  import google.generativeai as genai
18
 
19
  #genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
20
+ #genai.configure(api_key=os.getenv("GOOGLE_API_KEY_BACKUP"))
21
+ genai.configure(api_key=os.getenv("NEW_GOOGLE_API_KEY"))
22
 
23
  import nltk
24
  from nltk.corpus import stopwords
 
973
 
974
  raise RuntimeError("❌ Failed after max retries because of repeated rate limits.")
975
 
976
def outputs_from_multiPrompts(raw_response: str, output_format_str, acc_prompts):
    """Split a combined multi-prompt LLM response into per-accession outputs.

    The raw response is expected to contain one section per accession,
    delimited by '**Prompt N:' markers; section i is mapped to the i-th
    key of *acc_prompts* (dict insertion order).

    Args:
        raw_response: Full LLM response text containing all prompt answers.
        output_format_str: Field-format string forwarded to
            parse_multi_sample_llm_output when parsing each section.
        acc_prompts: Dict keyed by accession number; only its keys (and
            their order) are used here.

    Returns:
        Dict mapping each accession to the parsed metadata dict for its
        response section.
    """
    # Split on the '**Prompt X:' delimiters.
    # BUG FIX: the original called re.split(..., text) with `text`
    # undefined, which raised NameError on every call.
    sections = re.split(r'\*\*Prompt \d+:', raw_response)

    # Drop empty pieces produced by the split (e.g. a leading delimiter).
    prompts = [prompt.strip() for prompt in sections if prompt.strip()]

    outputs = {}
    accs = list(acc_prompts.keys())
    # Pair each section with the matching accession.  Guard with min():
    # the original indexed accs[i] unchecked and raised IndexError when
    # the LLM emitted more sections than there are accessions.
    for i in range(min(len(prompts), len(accs))):
        prompt_header = prompts[i].strip()
        prompt_header = re.sub(r'^\*\*\n', '', prompt_header)  # drop any leading '**\n'
        accession = accs[i]
        # NOTE(review): the original also appends the *next* section as a
        # "body", so adjacent sections overlap across iterations.  Behavior
        # preserved here — verify against the LLM's actual response layout.
        if i + 1 < len(prompts):
            prompt_body = prompts[i + 1].strip()
            output = f"{prompt_header}\n\n{prompt_body}"
        else:
            # Last section: no following body, keep the header only.
            output = f"{prompt_header}\n\n"
        metadata_list = parse_multi_sample_llm_output(output, output_format_str)
        outputs[accession] = metadata_list
    return outputs
1001
+
1002
def multi_prompts(dictsAccs, output_format_str, niche_cases=None, prompt_template="default"):
    """Build one LLM extraction prompt per accession number.

    Args:
        dictsAccs: Mapping of accession number -> context text, e.g.
            {"acc1": "text1", "acc2": "text2", "acc3": "text3"}.
        output_format_str: Comma-separated field list the LLM must return
            on the first line of its answer.
        niche_cases: Optional list of extra field names to request.
        prompt_template: Prompt layout selector; only "default" is
            currently implemented.

    Returns:
        Dict mapping each accession to ``[prompt_text, accession_found_in_text]``,
        where the boolean records whether the version-stripped accession
        literally occurs in its own context text.

    Raises:
        ValueError: If *prompt_template* is not "default".  (BUG FIX: the
            original fell through with ``prompt_for_llm`` unbound and
            crashed with UnboundLocalError instead.)
    """
    if prompt_template != "default":
        raise ValueError(f"Unsupported prompt_template: {prompt_template!r}")

    # Optional extra-field instructions, shared by every prompt.
    if niche_cases:
        fields_list = ", ".join(niche_cases)
        niche_prompt = (
            f"Also, extract {fields_list}. "
            f"If not explicitly stated, infer the most specific related or contextually relevant value. "
            f"If no information is found, write 'unknown'. "
        )
    else:
        niche_prompt = ""

    prompts = {}
    # enumerate from 1 so the "Prompt N" labels match the old acc_pos+1 numbering
    for acc_pos, (acc, context_for_llm) in enumerate(dictsAccs.items(), start=1):
        # Strip a trailing version suffix ("AB123.1" -> "AB123"); falsy
        # accessions (e.g. "") pass through unchanged, as before.
        acc_cleaned = acc.split('.')[0] if acc else acc

        prompt_for_llm = (
            f"Prompt {acc_pos}: "
            f"Given the following text snippets, analyze the entity/concept of this accession number {acc_cleaned} "
            f"Identify its **primary associated geographic location**, preferring the most specific available: "
            f"first try to determine the exact country; if no country is explicitly mentioned, then provide "
            f"the next most specific region, continent, island, or other clear geographic area mentioned. "
            f"If no geographic clues at all are present, state 'unknown' for location. "
            f"Also, determine if the genetic sample is from a 'modern' (present-day living individual) "
            f"or 'ancient' (prehistoric/archaeological) source. "
            f"If the text does not specify ancient or archaeological context, assume 'modern'. "
            f"{niche_prompt}"
            f"Provide only {output_format_str}. "
            f"If any information is not explicitly present, use the fallback rules above before defaulting to 'unknown'. "
            f"For each non-'unknown' field, write one sentence explaining how it was inferred from the text "
            f"(one sentence for each). "
            f"Format your answer so that:\n"
            f"1. The **first line** contains only the {output_format_str} values separated by commas.\n"
            f"2. The **second line onward** contains the explanations based on the order of the non-unknown {output_format_str} answer.\n"
            f"\nText Snippets:\n{context_for_llm}")

        # Record whether the accession itself appears in its context text.
        accession_found_in_text = acc_cleaned.lower() in context_for_llm.lower()
        prompts[acc] = [prompt_for_llm, accession_found_in_text]
    return prompts
1052
+
1053
async def getMoreInfoForAcc(iso=None, acc=None, saveLinkFolder=None, niche_cases=None, limit_context=250000):
    """Gather and size-limit web/NCBI context text for one accession.

    Pipeline: fetch NCBI metadata for *acc*, run a smart Google search,
    de-duplicate the hits, metadata-filter them, then download/extract
    text from each surviving link and concatenate it.  Oversized context
    (> 500k chars) is normalized and, if still too large, rebuilt with
    data_preprocess.build_context_for_llm or hard-truncated to
    *limit_context* characters.

    Args:
        iso: Isolate name forwarded to pipeline.process_link_allOutput.
        acc: Accession number driving the search and filtering.
        saveLinkFolder: Folder where fetched link content is cached/saved.
        niche_cases: Optional extra field names appended to the field list
            used when rebuilding a reduced context.
        limit_context: Max characters of the returned context text.

    Returns:
        Tuple ``(context_for_llm, linksWithTexts, links)``: the merged
        (possibly reduced) context string, the dict of filtered links with
        their metadata, and the ordered list of links actually processed.
    """
    linksWithTexts, links, context_for_llm = {}, [], ""
    # Expand metadata from NCBI, then search the web with it.
    meta_expand = smart_fallback.fetch_ncbi(acc)
    raw_tem_links = smart_fallback.smart_google_search(acc, meta_expand)
    # De-duplicate while keeping first-seen order.
    tem_links = pipeline.unique_preserve_order(raw_tem_links)
    print("this is tem links with acc: ", tem_links)
    # filter the quality link
    print("start the smart filter link")
    #success_process, output_process = run_with_timeout(smart_fallback.filter_links_by_metadata,args=(tem_links,saveLinkFolder),kwargs={"accession":acc},timeout=90)
    output_process = await smart_fallback.async_filter_links_by_metadata(
        tem_links, saveLinkFolder, accession=acc
    )
    print('inside getMoreInfoForAcc and here is outputProcess: ', output_process)
    if output_process:
        # Filter succeeded: keep only the vetted links.
        linksWithTexts.update(output_process)
        print("yeah we have linksWithTexts and len: ", len(linksWithTexts))
        print("yes succeed for smart filter link")
        links += list(linksWithTexts.keys())
        print("link keys: ", links)
    else:
        # Filter returned nothing: fall back to the raw search hits.
        print("not have output_process")
        links += tem_links
    if links:
        # use build context for llm function to reduce token
        texts_reduce = []
        linksWithTexts_reduce = {}
        reduce_context_for_llm = ""
        print("links:", links)
        # Links are processed sequentially (not gathered) on purpose —
        # each call receives linksWithTexts_reduce and the running
        # context, which are mutated/accumulated across iterations.
        for link in links:
            print("link: ", link)
            new_all_output = await pipeline.process_link_allOutput(link,
                                        iso, acc, saveLinkFolder, linksWithTexts_reduce, context_for_llm)
            print("done all output")
            context_for_llm += new_all_output
            texts_reduce.append(new_all_output)
            linksWithTexts_reduce[link] = {"all_output": new_all_output}
        # tasks = [
        #     pipeline.process_link_allOutput(link, iso, acc, saveLinkFolder, linksWithTexts, all_output)
        #     for link in links
        # ]
        # results = await asyncio.gather(*tasks)
        # print("this is result:", results)
        # # combine results
        # for new_all_output in results:
        #     context_for_llm += new_all_output
        print("len of context after merge all: ", len(context_for_llm))

        # First reduction pass: cheap normalization of overlapping text.
        if len(context_for_llm) > 500000:
            context_for_llm = data_preprocess.normalize_for_overlap(context_for_llm)
        # Second pass: still too big — rebuild a focused context, else truncate.
        if len(context_for_llm) > 500000:
            if links:
                input_prompt = ["country_name", "modern/ancient/unknown"]
                if niche_cases: input_prompt += niche_cases
                reduce_context_for_llm = data_preprocess.build_context_for_llm(texts_reduce, acc, input_prompt, limit_context)
            if reduce_context_for_llm:
                print("reduce context for llm")
                context_for_llm = reduce_context_for_llm
            else:
                print("no reduce context for llm despite>1M")
                context_for_llm = context_for_llm[:limit_context]
    return context_for_llm, linksWithTexts, links
1114
+
1115
def _estimate_llm_cost(counter_model, prompt_text, response_text, price_in_per_1k, price_out_per_1k):
    """Return the estimated dollar cost of one LLM call; 0 on any token-count failure."""
    try:
        input_llm_tokens = counter_model.count_tokens(prompt_text).total_tokens
        output_llm_tokens = counter_model.count_tokens(response_text).total_tokens
        print(f" DEBUG: LLM Input tokens: {input_llm_tokens}")
        print(f" DEBUG: LLM Output tokens: {output_llm_tokens}")
        llm_cost = (input_llm_tokens / 1000) * price_in_per_1k + \
                   (output_llm_tokens / 1000) * price_out_per_1k
        print(f" DEBUG: Estimated LLM cost: ${llm_cost:.6f}")
        return llm_cost
    except Exception as e:
        print(f" DEBUG: Error counting LLM tokens: {e}")
        return 0


async def query_document_info(niche_cases, saveLinkFolder, llm_api_function, prompts):
    """
    Queries the document using a hybrid approach:
    1. One batched RAG/LLM call covering every accession's prompt.
    2. Per-field retries with freshly fetched web context (getMoreInfoForAcc)
       whenever an answer comes back unknown/unspecified.

    Args:
        niche_cases: optional list of extra field names to extract in addition
            to country_name and modern/ancient/unknown.
        saveLinkFolder: cache folder forwarded to getMoreInfoForAcc.
        llm_api_function: kept for interface compatibility.
            # NOTE(review): unused — call_llm_api is what is actually invoked.
        prompts: dict mapping accession -> raw context text used both to build
            the batched prompts and as fallback context for retries.

    Returns:
        dict mapping accession -> {"predicted_output", "method_used",
        "total_query_cost", "links", "accession_found_in_text"}.
    """
    print("inside the model.query_doc_info")
    outputs, links, accession_found_in_text = {}, [], False

    genai.configure(api_key=os.getenv("NEW_GOOGLE_API_KEY"))
    # Gemini 2.5 Flash-Lite pricing per 1,000 tokens
    PRICE_PER_1K_INPUT_LLM = 0.00010   # $0.10 per 1M input tokens
    PRICE_PER_1K_OUTPUT_LLM = 0.00040  # $0.40 per 1M output tokens
    # Embedding-001 pricing per 1,000 input tokens
    PRICE_PER_1K_EMBEDDING_INPUT = 0.00015  # $0.15 per 1M input tokens
    global_llm_model_for_counting_tokens = genai.GenerativeModel("gemini-2.5-flash-lite")

    # Determine fields to ask the LLM for and the expected output format.
    output_format_str = "country_name, modern/ancient/unknown"
    method_used = 'rag_llm'  # updated if another method yields the result
    if niche_cases:
        output_format_str += ", " + ", ".join(niche_cases)
    total_query_cost, current_embedding_cost = 0, 0
    created_prompts = multi_prompts(prompts, output_format_str, niche_cases=niche_cases, prompt_template="default")
    print("done create prompt and length: ", len(created_prompts))
    prompt_for_llm = []
    for acc in created_prompts:
        # created_prompts[acc] is [prompt_text, accession_found_in_text_flag]
        outputs[acc] = {"predicted_output": "",
                        "method_used": method_used,
                        "total_query_cost": None,
                        "links": [],
                        "accession_found_in_text": created_prompts[acc][1],
                        }
        prompt_for_llm.append(created_prompts[acc][0])

    prompt_for_llm = "\n".join(prompt_for_llm)
    print("length of prompt: ", len(prompt_for_llm))
    print("use 2.5 flash gemini")
    llm_response_text, model_instance = call_llm_api(prompt_for_llm)
    print("\n--- DEBUG INFO FOR RAG ---")
    print("Retrieved Context Sent to LLM (first 500 chars):")
    print(prompt_for_llm[:500] + "..." if len(prompt_for_llm) > 500 else prompt_for_llm)
    print("\nRaw LLM Response:")
    print(llm_response_text)
    print("--- END DEBUG INFO ---")

    llm_cost = 0
    if model_instance:
        llm_cost = _estimate_llm_cost(global_llm_model_for_counting_tokens,
                                      prompt_for_llm, llm_response_text,
                                      PRICE_PER_1K_INPUT_LLM, PRICE_PER_1K_OUTPUT_LLM)

    total_query_cost += current_embedding_cost + llm_cost
    print(f" DEBUG: Total estimated cost for this RAG query: ${total_query_cost:.6f}")

    metadata_list = parse_multi_sample_llm_output(llm_response_text, output_format_str)
    multi_metadata_lists = [metadata_list]
    list_accs = list(prompts.keys())
    # Values whose presence means the LLM failed to answer a field.
    UNANSWERED = {"unknown", "unspecified", "could not get response from llm api.", "undefined"}
    for metadata_list_pos in range(len(multi_metadata_lists)):
        metadata_list = multi_metadata_lists[metadata_list_pos]
        print(metadata_list)
        acc = list_accs[metadata_list_pos]
        # BUGFIX: acc_cleaned must be recomputed per accession; it was computed
        # once before this loop from a stale loop variable.
        acc_cleaned = acc.split(".")[0] if acc else acc
        again_output_format, general_knowledge_prompt = "", ""
        output_acc = {}
        # If at least 1 answer is unknown, do smart queries to get more
        # sources besides the DOI.
        # BUGFIX: guard against a missing "answer" key (v.get may return None).
        unknown_count = sum(1 for v in metadata_list.values()
                            if (v.get("answer") or "").lower() == "unknown")
        if unknown_count >= 1:
            print("at least 1 unknown outputs")
            context_for_llm, linksWithTexts, more_links = await getMoreInfoForAcc(
                iso=None, acc=acc, saveLinkFolder=saveLinkFolder,
                niche_cases=niche_cases, limit_context=250000)
            links += more_links
            if acc_cleaned.lower() in context_for_llm.lower():
                accession_found_in_text = True
            # Refresh the flag and links now that new context was fetched.
            outputs[acc]["accession_found_in_text"] = accession_found_in_text
            outputs[acc]["links"] = links
        else:
            context_for_llm = prompts[acc]
        for key in metadata_list:
            answer = metadata_list[key]["answer"]
            # BUGFIX: exact membership test. The original used
            # `answer.lower() in " ".join([...])`, a substring test that also
            # matched fragments such as "un" or "response".
            if answer.lower() in UNANSWERED:
                print("have to do again")
                again_output_format = key
                print("output format:", again_output_format)
                general_knowledge_prompt = (
                    f"Given the following text snippets, analyze the entity/concept of this accession number {acc_cleaned} "
                    #f"or the mitochondrial DNA sample if these identifiers are not explicitly found. "
                    f"Identify and extract {again_output_format}"
                    f"If not explicitly stated, infer the most specific related or contextually relevant value. "
                    f"If no information is found, write 'unknown'. "
                    f"Provide only {again_output_format}. "
                    f"For non-'unknown' field in {again_output_format}, write one sentence explaining how it was inferred from the text "
                    f"Format your answer so that:\n"
                    f"1. The **first line** contains only the {again_output_format} answer.\n"
                    f"2. The **second line onward** contains the explanations based on the non-unknown {again_output_format} answer.\n"
                    f"\nText Snippets:\n{context_for_llm}")
                print("len of general prompt:", len(general_knowledge_prompt))
                if general_knowledge_prompt:
                    print("use 2.5 flash gemini")
                    llm_response_text, model_instance = call_llm_api(general_knowledge_prompt)
                    print("\n--- DEBUG INFO FOR RAG ---")
                    print("Retrieved Context Sent to LLM (first 500 chars):")
                    print(context_for_llm[:500] + "..." if len(context_for_llm) > 500 else context_for_llm)
                    print("\nRaw LLM Response:")
                    print(llm_response_text)
                    print("--- END DEBUG INFO ---")
                    llm_cost = 0
                    if model_instance:
                        # BUGFIX: count tokens of the retry prompt actually
                        # sent (the original re-counted the batched prompt).
                        llm_cost = _estimate_llm_cost(
                            global_llm_model_for_counting_tokens,
                            general_knowledge_prompt, llm_response_text,
                            PRICE_PER_1K_INPUT_LLM, PRICE_PER_1K_OUTPUT_LLM)
                    total_query_cost += current_embedding_cost + llm_cost
                    print("total query cost in again: ", total_query_cost)
                    metadata_list_niche = parse_multi_sample_llm_output(llm_response_text, again_output_format)
                    print(f"metadata list output for {again_output_format}: {metadata_list}")
                    for key_niche in metadata_list_niche:
                        # NOTE(review): this guard keeps field names from
                        # colliding with accession keys in `outputs`;
                        # presumably intended — confirm against callers.
                        if key_niche not in outputs.keys():
                            output_acc[key_niche] = metadata_list_niche[key_niche]
            else:
                output_acc[key] = metadata_list[key]
        outputs[acc]["predicted_output"] = output_acc
        outputs[acc]["total_query_cost"] = total_query_cost
        print("total cost: ", total_query_cost)
        print(f"total output of {acc}: {outputs[acc]}")
    return outputs