# mtDNAclassifier_flask/core/pipeline.py
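"""Classify the geographic origin and sample type of mtDNA GenBank accessions
by combining NCBI metadata, linked publications, and a Gemini-backed RAG
query, caching intermediate documents and RAG assets in Google Drive."""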
# test1: MJ17 direct
# test2: "A1YU101" thailand cross-ref
# test3: "EBK109" thailand cross-ref
# test4: "OQ731952"/"BST115" for search query title: "South Asian maternal and paternal lineages in southern Thailand and"
import os, io, time, re, json
import pandas as pd
import subprocess
import multiprocessing
from pathlib import Path
from typing import Any, Dict, List, Optional
import google.generativeai as genai
# Google Drive (optional)
from google.oauth2.service_account import Credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
import gspread
from oauth2client.service_account import ServiceAccountCredentials
# ---- core modules (must exist in your project) ----
import core.mtdna_classifier as mtdna_classifier
import core.data_preprocess as data_preprocess
import core.model as model
import core.smart_fallback as smart_fallback
import core.standardize_location as standardize_location
from core.NER.html import extractHTML
from core.drive_utils import *
def run_with_timeout(func, args=(), kwargs=None, timeout=30):
    """Run `func(*args, **kwargs)` in a separate process, killing it after
    `timeout` seconds. Returns (True, result) on success and (False, None)
    on timeout; re-raises any exception raised inside `func`."""
    kwargs = kwargs or {}  # avoid a mutable default argument

    def wrapper(q, *args, **kwargs):
        try:
            q.put((True, func(*args, **kwargs)))
        except Exception as e:
            q.put((False, e))

    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=wrapper, args=(q, *args), kwargs=kwargs)
    p.start()
    p.join(timeout)
    if p.is_alive():
        p.terminate()
        p.join()
        print(f"⏱️ Timeout exceeded ({timeout} sec) — function killed.")
        return False, None
    if not q.empty():
        success, result = q.get()
        if success:
            return True, result
        raise result  # re-raise the exception from the child process
    return False, None
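# A subprocess (rather than a thread) is used above so that a hung call can
# actually be killed: Python threads cannot be forcibly terminated. Minimal
# usage sketch (urlopen stands in for any slow call):
#
#   from urllib.request import urlopen
#   ok, resp = run_with_timeout(urlopen, args=("https://example.com",), timeout=10)
#   if not ok:
#       resp = None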
def time_it(func, *args, **kwargs):
"""
Measure how long a function takes to run and return its result + time.
"""
start = time.time()
result = func(*args, **kwargs)
end = time.time()
elapsed = end - start
print(f"⏱️ '{func.__name__}' took {elapsed:.3f} seconds")
return result, elapsed
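# Example: result, elapsed = time_it(sum, range(10**6))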
def unique_preserve_order(seq):
    """Deduplicate `seq` while preserving first-seen order."""
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]
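# Example: unique_preserve_order(["a", "b", "a", "c"]) -> ["a", "b", "c"]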
# Main execution
def pipeline_with_gemini(accessions, stop_flag=None, niche_cases=None, save_df=None):
    """For each accession, infer country and sample type (plus optional niche
    fields) from NCBI metadata, linked publications, and a Gemini-backed RAG
    query. Returns {accession: acc_score} with sources, query cost, and time cost.
    `accessions` may contain a single accession number."""
    # Check the stop flag before each big step so a UI cancel aborts early.
    if stop_flag is not None and stop_flag.value:
        print("🛑 Stop detected before starting, aborting early...")
        return {}
# Gemini 2.5 Flash-Lite pricing per 1,000 tokens
PRICE_PER_1K_INPUT_LLM = 0.00010 # $0.10 per 1M input tokens
PRICE_PER_1K_OUTPUT_LLM = 0.00040 # $0.40 per 1M output tokens
# Embedding-001 pricing per 1,000 input tokens
PRICE_PER_1K_EMBEDDING_INPUT = 0.00015 # $0.15 per 1M input tokens
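    # Example: a call with 12,000 input and 800 output tokens costs
    # (12000 / 1000) * 0.00010 + (800 / 1000) * 0.00040 = $0.00152.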
if not accessions:
print("no input")
return None
else:
accs_output = {}
#genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
genai.configure(api_key=os.getenv("GOOGLE_API_KEY_BACKUP"))
for acc in accessions:
print("start gemini: ", acc)
start = time.time()
total_cost_title = 0
jsonSM, links, article_text = {},[], ""
acc_score = { "isolate": "",
"country":{},
"sample_type":{},
#"specific_location":{},
#"ethnicity":{},
"query_cost":total_cost_title,
"time_cost":None,
"source":links,
"file_chunk":"",
"file_all_output":""}
if niche_cases:
for niche in niche_cases:
acc_score[niche] = {}
meta = mtdna_classifier.fetch_ncbi_metadata(acc)
            country, spe_loc, ethnic, sample_type, col_date, iso, title, doi, pudID, features = (
                meta["country"], meta["specific_location"], meta["ethnicity"], meta["sample_type"],
                meta["collection_date"], meta["isolate"], meta["title"], meta["doi"],
                meta["pubmed_id"], meta["all_features"],
            )
acc_score["isolate"] = iso
print("meta: ",meta)
meta_expand = smart_fallback.fetch_ncbi(acc)
print("meta expand: ", meta_expand)
# set up step: create the folder to save document
chunk, all_output = "",""
if pudID:
id = str(pudID)
saveTitle = title
else:
                try:
                    author_name = meta_expand["authors"].split(',')[0]  # use last name only
                except Exception:
                    author_name = meta_expand["authors"]
saveTitle = title + "_" + col_date + "_" + author_name
if title.lower() == "unknown" and col_date.lower()=="unknown" and author_name.lower() == "unknown":
saveTitle += "_" + acc
id = "DirectSubmission"
data_folder_id = GDRIVE_DATA_FOLDER_NAME # Use the shared folder directly
sample_folder_id = get_or_create_drive_folder(str(id), parent_id=data_folder_id)
print("sample folder id: ", sample_folder_id)
            # Define document names (truncate long titles to keep filenames manageable)
            saveName = (saveTitle[:50] if len(saveTitle) > 50 else saveTitle).replace(" ", "_")
            chunk_filename = f"{saveName}_merged_document.docx"
            all_filename = f"{saveName}_all_merged_document.docx"
print("chunk file name and all filename: ", chunk_filename, all_filename)
            # Define local temp paths for reading/writing before uploading to Drive
LOCAL_TEMP_DIR = "/mnt/data/generated_docs"
os.makedirs(LOCAL_TEMP_DIR, exist_ok=True)
file_chunk_path = os.path.join(LOCAL_TEMP_DIR, chunk_filename)
file_all_path = os.path.join(LOCAL_TEMP_DIR, all_filename)
            if stop_flag is not None and stop_flag.value:
                print(f"🛑 Stop processing {acc}, aborting early...")
                return {}
print("this is file chunk path: ", file_chunk_path)
chunk_id = find_drive_file(chunk_filename, sample_folder_id)
all_id = find_drive_file(all_filename, sample_folder_id)
if chunk_id and all_id:
print("✅ Files already exist in Google Drive. Downloading them...")
chunk_exists = download_file_from_drive(chunk_filename, sample_folder_id, file_chunk_path)
all_exists = download_file_from_drive(all_filename, sample_folder_id, file_all_path)
acc_score["file_chunk"] = str(chunk_filename)
acc_score["file_all_output"] = str(all_filename)
print("chunk_id and all_id: ")
print(chunk_id, all_id)
print("file chunk and all output saved in acc score: ", acc_score["file_chunk"], acc_score["file_all_output"])
                # Log the located chunk file's Drive metadata for debugging
                file = drive_service.files().get(fileId=chunk_id, fields="id, name, parents, webViewLink").execute()
                print("📄 Name:", file["name"])
                print("📁 Parent folder ID:", file["parents"][0])
                print("🔗 View link:", file["webViewLink"])
# Read and parse these into `chunk` and `all_output`
else:
# 🔥 Remove any stale local copies
if os.path.exists(file_chunk_path):
os.remove(file_chunk_path)
print(f"🗑️ Removed stale: {file_chunk_path}")
if os.path.exists(file_all_path):
os.remove(file_all_path)
print(f"🗑️ Removed stale: {file_all_path}")
# Try to download if already exists on Drive
chunk_exists = download_file_from_drive(chunk_filename, sample_folder_id, file_chunk_path)
all_exists = download_file_from_drive(all_filename, sample_folder_id, file_all_path)
print("chunk exist: ", chunk_exists)
# first way: ncbi method
print("country.lower: ",country.lower())
if country.lower() != "unknown":
stand_country = standardize_location.smart_country_lookup(country.lower())
print("stand_country: ", stand_country)
if stand_country.lower() != "not found":
acc_score["country"][stand_country.lower()] = ["ncbi"]
else: acc_score["country"][country.lower()] = ["ncbi"]
# if spe_loc.lower() != "unknown":
# acc_score["specific_location"][spe_loc.lower()] = ["ncbi"]
# if ethnic.lower() != "unknown":
# acc_score["ethnicity"][ethnic.lower()] = ["ncbi"]
if sample_type.lower() != "unknown":
acc_score["sample_type"][sample_type.lower()] = ["ncbi"]
# second way: LLM model
# Preprocess the input token
print(acc_score)
accession, isolate = None, None
if acc != "unknown": accession = acc
if iso != "unknown": isolate = iso
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
# check doi first
print("chunk filename: ", chunk_filename)
if chunk_exists:
print("File chunk exists!")
if not chunk:
print("start to get chunk")
text, table, document_title = model.read_docx_text(file_chunk_path)
chunk = data_preprocess.normalize_for_overlap(text) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table))
if str(chunk_filename) != "":
print("first time have chunk path at chunk exist: ", str(chunk_filename))
acc_score["file_chunk"] = str(chunk_filename)
if all_exists:
print("File all output exists!")
if not all_output:
text_all, table_all, document_title_all = model.read_docx_text(file_all_path)
all_output = data_preprocess.normalize_for_overlap(text_all) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table_all))
if str(all_filename) != "":
print("first time have all path at all exist: ", str(all_filename))
acc_score["file_all_output"] = str(all_filename)
print("acc sscore for file all output and chunk: ", acc_score["file_all_output"], acc_score["file_chunk"])
if len(acc_score["file_all_output"]) == 0 and len(acc_score["file_chunk"]) == 0:
if doi != "unknown":
link = 'https://doi.org/' + doi
# get the file to create listOfFile for each id
print("link of doi: ", link)
html = extractHTML.HTML("",link)
jsonSM = html.getSupMaterial()
article_text = html.getListSection()
if article_text:
if "Just a moment...Enable JavaScript and cookies to continue".lower() not in article_text.lower() or "403 Forbidden Request".lower() not in article_text.lower():
links.append(link)
                    if jsonSM:
                        links += sum((jsonSM[key] for key in jsonSM), [])  # flatten {section: [urls]} into one list
# no doi then google custom search api
if doi=="unknown" or len(article_text) == 0 or "Just a moment...Enable JavaScript and cookies to continue".lower() in article_text.lower() or "403 Forbidden Request".lower() in article_text.lower():
# might find the article
print("no article text, start tem link")
#tem_links = mtdna_classifier.search_google_custom(title, 2)
tem_links = smart_fallback.smart_google_search(meta_expand)
print("tem links: ", tem_links)
tem_link_acc = smart_fallback.google_accession_search(acc)
tem_links += tem_link_acc
tem_links = unique_preserve_order(tem_links)
print("tem link before filtering: ", tem_links)
# filter the quality link
print("saveLinkFolder as sample folder id: ", sample_folder_id)
print("start the smart filter link")
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
links = smart_fallback.filter_links_by_metadata(tem_links, saveLinkFolder=sample_folder_id, accession=acc, stop_flag=stop_flag)
print("this is links: ",links)
links = unique_preserve_order(links)
acc_score["source"] = links
else:
print("inside the try of reusing chunk or all output")
#print("chunk filename: ", str(chunks_filename))
try:
temp_source = False
if save_df is not None and not save_df.empty:
print("save df not none")
print("chunk file name: ",str(chunk_filename))
print("all filename: ",str(all_filename))
if acc_score["file_chunk"]:
link = save_df.loc[save_df["file_chunk"]==acc_score["file_chunk"],"Sources"].iloc[0]
#link = row["Sources"].iloc[0]
if "http" in link:
print("yeah http in save df source")
acc_score["source"] = [x for x in link.split("\n") if x.strip()]#row["Sources"].tolist()
else: # temporary
print("tempo source")
#acc_score["source"] = [str(all_filename), str(chunks_filename)]
temp_source = True
elif acc_score["file_all_output"]:
link = save_df.loc[save_df["file_all_output"]==acc_score["file_all_output"],"Sources"].iloc[0]
#link = row["Sources"].iloc[0]
print(link)
print("list of link")
print([x for x in link.split("\n") if x.strip()])
if "http" in link:
print("yeah http in save df source")
acc_score["source"] = [x for x in link.split("\n") if x.strip()]#row["Sources"].tolist()
else: # temporary
print("tempo source")
#acc_score["source"] = [str(all_filename), str(chunks_filename)]
temp_source = True
else: # temporary
print("tempo source")
#acc_score["source"] = [str(file_all_path), str(file_chunk_path)]
temp_source = True
else: # temporary
print("tempo source")
#acc_score["source"] = [str(file_all_path), str(file_chunk_path)]
temp_source = True
if temp_source:
print("temp source is true so have to try again search link")
if doi != "unknown":
link = 'https://doi.org/' + doi
# get the file to create listOfFile for each id
print("link of doi: ", link)
html = extractHTML.HTML("",link)
jsonSM = html.getSupMaterial()
article_text = html.getListSection()
if article_text:
if "Just a moment...Enable JavaScript and cookies to continue".lower() not in article_text.lower() or "403 Forbidden Request".lower() not in article_text.lower():
links.append(link)
if jsonSM:
links += sum((jsonSM[key] for key in jsonSM),[])
# no doi then google custom search api
if doi=="unknown" or len(article_text) == 0 or "Just a moment...Enable JavaScript and cookies to continue".lower() in article_text.lower() or "403 Forbidden Request".lower() in article_text.lower():
# might find the article
print("no article text, start tem link")
#tem_links = mtdna_classifier.search_google_custom(title, 2)
tem_links = smart_fallback.smart_google_search(meta_expand)
print("tem links: ", tem_links)
tem_link_acc = smart_fallback.google_accession_search(acc)
tem_links += tem_link_acc
tem_links = unique_preserve_order(tem_links)
print("tem link before filtering: ", tem_links)
# filter the quality link
print("saveLinkFolder as sample folder id: ", sample_folder_id)
print("start the smart filter link")
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
links = smart_fallback.filter_links_by_metadata(tem_links, saveLinkFolder=sample_folder_id, accession=acc, stop_flag=stop_flag)
print("this is links: ",links)
links = unique_preserve_order(links)
acc_score["source"] = links
                except Exception as e:
                    print("exception while resolving sources: ", e)
                    acc_score["source"] = []
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
# print("chunk filename: ", chunk_filename)
# if chunk_exists:
# print("File chunk exists!")
# if not chunk:
# print("start to get chunk")
# text, table, document_title = model.read_docx_text(file_chunk_path)
# chunk = data_preprocess.normalize_for_overlap(text) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table))
# if str(chunk_filename) != "":
# print("first time have chunk path at chunk exist: ", str(chunk_filename))
# acc_score["file_chunk"] = str(chunk_filename)
# if all_exists:
# print("File all output exists!")
# if not all_output:
# text_all, table_all, document_title_all = model.read_docx_text(file_all_path)
# all_output = data_preprocess.normalize_for_overlap(text_all) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table_all))
# if str(all_filename) != "":
# print("first time have all path at all exist: ", str(all_filename))
# acc_score["file_all_output"] = str(all_filename)
            if not chunk and not all_output:
                print("no cached chunk or all_output; building from source links")
                # record the target filenames so later runs can reuse the cached docs
                if str(chunk_filename) != "":
                    print("first time have chunk path: ", str(chunk_filename))
                    acc_score["file_chunk"] = str(chunk_filename)
                if str(all_filename) != "":
                    print("first time have all path: ", str(all_filename))
                    acc_score["file_all_output"] = str(all_filename)
if links:
for link in links:
print(link)
if len(data_preprocess.normalize_for_overlap(all_output)) > 600000:
print("break here")
break
if iso != "unknown": query_kw = iso
else: query_kw = acc
#text_link, tables_link, final_input_link = data_preprocess.preprocess_document(link,saveLinkFolder, isolate=query_kw)
success_process, output_process = run_with_timeout(data_preprocess.preprocess_document,args=(link,sample_folder_id),kwargs={"isolate":query_kw,"accession":acc},timeout=100)
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
if success_process:
text_link, tables_link, final_input_link = output_process[0], output_process[1], output_process[2]
print("yes succeed for process document")
else: text_link, tables_link, final_input_link = "", "", ""
context = data_preprocess.extract_context(final_input_link, query_kw)
if context != "Sample ID not found.":
if len(data_preprocess.normalize_for_overlap(chunk)) < 1000*1000:
success_chunk, the_output_chunk = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(chunk, context))
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
if success_chunk:
chunk = the_output_chunk#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
print("yes succeed for chunk")
else:
chunk += context
print("len context: ", len(context))
print("basic fall back")
print("len chunk after: ", len(chunk))
if len(final_input_link) > 1000*1000:
if context != "Sample ID not found.":
final_input_link = context
else:
final_input_link = data_preprocess.normalize_for_overlap(final_input_link)
if len(final_input_link) > 1000 *1000:
final_input_link = final_input_link[:100000]
if len(data_preprocess.normalize_for_overlap(all_output)) < int(100000) and len(final_input_link)<100000:
print("Running merge_texts_skipping_overlap with timeout")
success, the_output = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(all_output, final_input_link),timeout=30)
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
print("Returned from timeout logic")
if success:
all_output = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
print("yes succeed")
else:
print("len all output: ", len(all_output))
print("len final input link: ", len(final_input_link))
all_output += final_input_link
print("len final input: ", len(final_input_link))
print("basic fall back")
else:
print("both/either all output or final link too large more than 100000")
print("len all output: ", len(all_output))
print("len final input link: ", len(final_input_link))
all_output += final_input_link
print("len final input: ", len(final_input_link))
print("basic fall back")
print("len all output after: ", len(all_output))
#country_pro, chunk, all_output = data_preprocess.process_inputToken(links, saveLinkFolder, accession=accession, isolate=isolate)
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
else:
chunk = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
if not chunk: chunk = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
if not all_output: all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
if len(all_output) > 1*1024*1024:
all_output = data_preprocess.normalize_for_overlap(all_output)
if len(all_output) > 1*1024*1024:
all_output = all_output[:1*1024*1024]
print("chunk len: ", len(chunk))
print("all output len: ", len(all_output))
data_preprocess.save_text_to_docx(chunk, file_chunk_path)
data_preprocess.save_text_to_docx(all_output, file_all_path)
# Upload to Drive
result_chunk_upload = upload_file_to_drive(file_chunk_path, chunk_filename, sample_folder_id)
result_all_upload = upload_file_to_drive(file_all_path, all_filename, sample_folder_id)
print("UPLOAD RESULT FOR CHUNK: ", result_chunk_upload)
print(f"🔗 Uploaded file: https://drive.google.com/file/d/{result_chunk_upload}/view")
print("here 1")
# Define paths for cached RAG assets
print("here 2")
faiss_filename = "faiss_index.bin"
chunks_filename = "document_chunks.json"
lookup_filename = "structured_lookup.json"
print("name of faiss: ", faiss_filename)
faiss_index_path = os.path.join(LOCAL_TEMP_DIR, faiss_filename)
document_chunks_path = os.path.join(LOCAL_TEMP_DIR, chunks_filename)
structured_lookup_path = os.path.join(LOCAL_TEMP_DIR, lookup_filename)
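            # Three cached RAG assets per sample folder:
            #   faiss_index.bin        — vector index over the document chunks
            #   document_chunks.json   — chunk texts backing the index
            #   structured_lookup.json — structured metadata (e.g. document title)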
print("name if faiss path: ", faiss_index_path)
# 🔥 Remove the local file first if it exists
print("start faiss id and also the sample folder id is: ", sample_folder_id)
faiss_id = find_drive_file(faiss_filename, sample_folder_id)
print("done faiss id")
document_id = find_drive_file(chunks_filename, sample_folder_id)
structure_id = find_drive_file(lookup_filename, sample_folder_id)
if faiss_id and document_id and structure_id:
print("✅ 3 Files already exist in Google Drive. Downloading them...")
download_file_from_drive(faiss_filename, sample_folder_id, faiss_index_path)
download_file_from_drive(chunks_filename, sample_folder_id, document_chunks_path)
download_file_from_drive(lookup_filename, sample_folder_id, structured_lookup_path)
# Read and parse these into `chunk` and `all_output`
            else:
                print("at least one cached RAG asset is missing from Drive; refreshing local copies")
if os.path.exists(faiss_index_path):
print("faiss index exist and start to remove: ", faiss_index_path)
os.remove(faiss_index_path)
if os.path.exists(document_chunks_path):
os.remove(document_chunks_path)
if os.path.exists(structured_lookup_path):
os.remove(structured_lookup_path)
print("start to download the faiss, chunk, lookup")
download_file_from_drive(faiss_filename, sample_folder_id, faiss_index_path)
download_file_from_drive(chunks_filename, sample_folder_id, document_chunks_path)
download_file_from_drive(lookup_filename, sample_folder_id, structured_lookup_path)
            # Pre-define these so the gemini-1.5 fallback in the except branch
            # cannot hit a NameError if the try block fails early.
            primary_word, alternative_word = iso, acc
            master_structured_lookup, faiss_index, document_chunks = {}, None, []
            try:
                print("try gemini 2.5")
                print("move to load rag")
                master_structured_lookup, faiss_index, document_chunks = model.load_rag_assets(
                    faiss_index_path, document_chunks_path, structured_lookup_path
                )
                global_llm_model_for_counting_tokens = genai.GenerativeModel('gemini-1.5-flash-latest')
if not all_output:
if chunk: all_output = chunk
else: all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
if faiss_index is None:
print("\nBuilding RAG assets (structured lookup, FAISS index, chunks)...")
total_doc_embedding_tokens = global_llm_model_for_counting_tokens.count_tokens(
all_output
).total_tokens
initial_embedding_cost = (total_doc_embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT
total_cost_title += initial_embedding_cost
print(f"Initial one-time embedding cost for '{file_all_path}' ({total_doc_embedding_tokens} tokens): ${initial_embedding_cost:.6f}")
master_structured_lookup, faiss_index, document_chunks, plain_text_content = model.build_vector_index_and_data(
file_all_path, faiss_index_path, document_chunks_path, structured_lookup_path
)
else:
print("\nRAG assets loaded from file. No re-embedding of entire document will occur.")
plain_text_content_all, table_strings_all, document_title_all = model.read_docx_text(file_all_path)
master_structured_lookup['document_title'] = master_structured_lookup.get('document_title', document_title_all)
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
                print(f"\n--- General Query: Primary='{primary_word}' (Alternative='{alternative_word}') ---")
if features.lower() not in all_output.lower():
all_output += ". NCBI Features: " + features
print("this is chunk for the model")
print(chunk)
print("this is all output for the model")
print(all_output)
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
country, sample_type, method_used, country_explanation, sample_type_explanation, total_query_cost = model.query_document_info(
primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
model.call_llm_api, chunk=chunk, all_output=all_output)
print("pass query of 2.5")
            except Exception as e:
                print("gemini 2.5 path failed, falling back to gemini 1.5: ", e)
country, sample_type, ethnic, spe_loc, method_used, country_explanation, sample_type_explanation, ethnicity_explanation, specific_loc_explanation, total_query_cost = model.query_document_info(
primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
model.call_llm_api, chunk=chunk, all_output=all_output, model_ai="gemini-1.5-flash-latest")
print("yeah pass the query of 1.5")
print("country using ai: ", country)
print("sample type using ai: ", sample_type)
if len(country) == 0: country = "unknown"
if len(sample_type) == 0: sample_type = "unknown"
if country_explanation and country_explanation!="unknown": country_explanation = "-"+country_explanation
else: country_explanation = ""
if sample_type_explanation and sample_type_explanation!="unknown": sample_type_explanation = "-"+sample_type_explanation
else: sample_type_explanation = ""
if method_used == "unknown": method_used = ""
if country.lower() != "unknown":
stand_country = standardize_location.smart_country_lookup(country.lower())
if stand_country.lower() != "not found":
if stand_country.lower() in acc_score["country"]:
if country_explanation:
acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
else:
acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
else:
if country.lower() in acc_score["country"]:
if country_explanation:
if len(method_used + country_explanation) > 0:
acc_score["country"][country.lower()].append(method_used + country_explanation)
else:
if len(method_used + country_explanation) > 0:
acc_score["country"][country.lower()] = [method_used + country_explanation]
# if spe_loc.lower() != "unknown":
# if spe_loc.lower() in acc_score["specific_location"]:
# acc_score["specific_location"][spe_loc.lower()].append(method_used)
# else:
# acc_score["specific_location"][spe_loc.lower()] = [method_used]
# if ethnic.lower() != "unknown":
# if ethnic.lower() in acc_score["ethnicity"]:
# acc_score["ethnicity"][ethnic.lower()].append(method_used)
# else:
# acc_score["ethnicity"][ethnic.lower()] = [method_used]
if sample_type.lower() != "unknown":
if sample_type.lower() in acc_score["sample_type"]:
if len(method_used + sample_type_explanation) > 0:
acc_score["sample_type"][sample_type.lower()].append(method_used + sample_type_explanation)
else:
if len(method_used + sample_type_explanation)> 0:
acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
total_cost_title += total_query_cost
if stop_flag is not None and stop_flag.value:
print(f"🛑 Stop processing {accession}, aborting early...")
return {}
# last resort: combine all information to give all output otherwise unknown
if len(acc_score["country"]) == 0 or len(acc_score["sample_type"]) == 0 or acc_score["country"] == "unknown" or acc_score["sample_type"] == "unknown":
text = ""
for key in meta_expand:
text += str(key) + ": " + meta_expand[key] + "\n"
if len(data_preprocess.normalize_for_overlap(all_output)) > 0:
text += data_preprocess.normalize_for_overlap(all_output)
if len(data_preprocess.normalize_for_overlap(chunk)) > 0:
text += data_preprocess.normalize_for_overlap(chunk)
text += ". NCBI Features: " + features
print("this is text for the last resort model")
print(text)
country, sample_type, method_used, country_explanation, sample_type_explanation, total_query_cost = model.query_document_info(
primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
model.call_llm_api, chunk=text, all_output=text)
print("this is last resort results: ")
print("country: ", country)
print("sample type: ", sample_type)
if len(country) == 0: country = "unknown"
if len(sample_type) == 0: sample_type = "unknown"
if country_explanation and country_explanation!="unknown": country_explanation = "-"+country_explanation
else: country_explanation = ""
if sample_type_explanation and sample_type_explanation!="unknown": sample_type_explanation = "-"+sample_type_explanation
else: sample_type_explanation = ""
if method_used == "unknown": method_used = ""
if country.lower() != "unknown":
stand_country = standardize_location.smart_country_lookup(country.lower())
if stand_country.lower() != "not found":
if stand_country.lower() in acc_score["country"]:
if country_explanation:
acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
else:
acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
else:
if country.lower() in acc_score["country"]:
if country_explanation:
if len(method_used + country_explanation) > 0:
acc_score["country"][country.lower()].append(method_used + country_explanation)
else:
if len(method_used + country_explanation) > 0:
acc_score["country"][country.lower()] = [method_used + country_explanation]
if sample_type.lower() != "unknown":
if sample_type.lower() in acc_score["sample_type"]:
if len(method_used + sample_type_explanation) > 0:
acc_score["sample_type"][sample_type.lower()].append(method_used + sample_type_explanation)
else:
if len(method_used + sample_type_explanation)> 0:
acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
total_cost_title += total_query_cost
end = time.time()
acc_score["query_cost"] = f"{total_cost_title:.6f}"
elapsed = end - start
acc_score["time_cost"] = f"{elapsed:.3f} seconds"
accs_output[acc] = acc_score
print(accs_output[acc])
return accs_output
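# A minimal usage sketch, assuming GOOGLE_API_KEY_BACKUP is set and the Drive
# helpers in core.drive_utils are configured; the accession is one of the test
# IDs noted at the top of this file.
if __name__ == "__main__":
    results = pipeline_with_gemini(["OQ731952"])
    for acc, score in (results or {}).items():
        print(acc, score["country"], score["sample_type"], score["query_cost"], score["time_cost"])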