|
|
import re, os, json |
|
|
import pycountry, faiss |
|
|
from docx import Document |
|
|
import numpy as np |
|
|
from collections import defaultdict |
|
|
import ast |
|
|
import math |
|
|
import core.data_preprocess |
|
|
import core.mtdna_classifier |
|
|
|
|
|
import google.generativeai as genai |
|
|
|
|
|
|
|
|
# Configure the Gemini client once at import time. The key is read from the
# GOOGLE_API_KEY_BACKUP environment variable (None when it is not set).
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY_BACKUP"))
|
|
|
|
|
import nltk |
|
|
from nltk.corpus import stopwords |
|
|
# Make sure the NLTK data packages this file asks for are installed locally.
# Each package is probed on its own, so a missing tokenizer model is fetched
# even when the stopword corpus is already present, and nothing is downloaded
# again once both are available.
for _nltk_resource_path, _nltk_package in (
    ("corpora/stopwords", "stopwords"),
    ("tokenizers/punkt_tab", "punkt_tab"),
):
    try:
        nltk.data.find(_nltk_resource_path)
    except LookupError:
        nltk.download(_nltk_package)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default per-1K prices for the LLM and the embedding model.
# NOTE(review): presumably USD per 1,000 tokens -- confirm against the
# provider's price list. query_document_info assigns same-named local values
# depending on the selected model.
PRICE_PER_1K_INPUT_LLM = 0.0001
PRICE_PER_1K_OUTPUT_LLM = 0.0004

PRICE_PER_1K_EMBEDDING_INPUT = 0.00015
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embedding(text, task_type="RETRIEVAL_DOCUMENT"):
    """Embed ``text`` with the ``models/text-embedding-004`` model.

    Args:
        text: String to embed. ``None``, empty and whitespace-only values are
            treated as an error.
        task_type: Forwarded unchanged as ``task_type`` to
            ``genai.embed_content``.

    Returns:
        A ``float32`` NumPy vector. Any failure (empty input or an exception
        raised by the API call) is printed and answered with a 768-element
        zero vector instead of being propagated.
    """
    import numpy as np

    fallback_dimension = 768
    try:
        if not (text and text.strip()):
            raise ValueError("Empty text cannot be embedded.")
        response = genai.embed_content(
            model="models/text-embedding-004",
            content=text,
            task_type=task_type,
        )
        vector = np.array(response['embedding'], dtype='float32')
    except Exception as e:
        print(f"❌ Embedding error: {e}")
        vector = np.zeros(fallback_dimension, dtype='float32')
    return vector
|
|
|
|
|
|
|
|
def call_llm_api(prompt, model_name="gemini-2.5-flash-lite"):
    """Send ``prompt`` to a Gemini model and return its text answer.

    Args:
        prompt: Prompt passed to ``generate_content``.
        model_name: Name handed to ``genai.GenerativeModel``.

    Returns:
        ``(response_text, model)`` on success. If anything raises, the error
        is printed and ``("Error: Could not get response from LLM API.",
        None)`` is returned instead.
    """
    failure = ("Error: Could not get response from LLM API.", None)
    try:
        llm = genai.GenerativeModel(model_name)
        reply = llm.generate_content(prompt)
        return reply.text, llm
    except Exception as e:
        print(f"Error calling LLM: {e}")
        return failure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_docx_text(path):
    """Split a .docx file into plain text, table strings and a title.

    A table starts at a paragraph beginning with ``"## Table "`` and keeps
    absorbing the following paragraphs as long as they begin with ``[`` or
    ``"``. Every other non-empty paragraph ends the current table (if any)
    and is collected as plain text.

    The title is guessed from the first five paragraphs: the first (else the
    second) non-empty one that is longer than 50 characters and does not
    contain "Human Genetics"; failing that, a fixed title is used when any of
    them mentions "Complete mitochondrial genomes".

    Args:
        path: Location of the .docx file, passed to ``Document``.

    Returns:
        ``(plain_text, table_strings, document_title)`` where ``plain_text``
        is the newline-joined plain paragraphs and ``table_strings`` is a
        list with one newline-joined string per detected table.
    """
    paragraphs = Document(path).paragraphs

    def is_title_candidate(candidate):
        return len(candidate) > 50 and "Human Genetics" not in candidate

    document_title = "Unknown Document Title"
    leading_texts = [t for t in (p.text.strip() for p in paragraphs[:5]) if t]
    if leading_texts:
        if is_title_candidate(leading_texts[0]):
            document_title = leading_texts[0]
        elif len(leading_texts) > 1 and is_title_candidate(leading_texts[1]):
            document_title = leading_texts[1]
        elif any("Complete mitochondrial genomes" in t for t in leading_texts):
            document_title = "Complete mitochondrial genomes of Thai and Lao populations indicate an ancient origin of Austroasiatic groups and demic diffusion in the spread of Tai–Kadai languages"

    plain_text_paragraphs = []
    table_strings = []
    pending_table = []  # non-empty exactly while a table is being collected

    for paragraph in paragraphs:
        line = paragraph.text.strip()
        if not line:
            continue

        if line.startswith("## Table "):
            # A new table header closes the table collected so far.
            if pending_table:
                table_strings.append("\n".join(pending_table))
            pending_table = [line]
        elif pending_table and line.startswith(("[", '"')):
            pending_table.append(line)
        else:
            if pending_table:
                table_strings.append("\n".join(pending_table))
                pending_table = []
            plain_text_paragraphs.append(line)

    # Flush a table that runs to the end of the document.
    if pending_table:
        table_strings.append("\n".join(pending_table))

    return "\n".join(plain_text_paragraphs), table_strings, document_title
|
|
|
|
|
|
|
|
|
|
|
def parse_literal_python_list(table_str):
    """Extract and evaluate the first ``[[ ... ]]`` literal found in a string.

    If no such literal is found and the string mentions "table"
    (case-insensitive), ``"]]"`` is appended and the search is retried, which
    lets a truncated table still be parsed.

    Args:
        table_str: Text that may contain a Python list-of-lists literal.

    Returns:
        The value produced by ``ast.literal_eval`` for the matched text, or
        ``[]`` when nothing matches or the literal cannot be evaluated (the
        evaluation error is printed).
    """
    nested_list_regex = r'(\[\s*\[\s*(?:.|\n)*?\s*\]\s*\])'

    found = re.search(nested_list_regex, table_str)
    if found is None and "table" in table_str.lower():
        found = re.search(nested_list_regex, table_str + "]]")
    if found is None:
        return []

    try:
        return ast.literal_eval(found.group(1))
    except (ValueError, SyntaxError) as e:
        print(f"Error evaluating literal: {e}")
        return []
|
|
|
|
|
|
|
|
_individual_code_parser = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE) |
|
|
def _parse_individual_code_parts(code_str): |
|
|
match = _individual_code_parser.search(code_str) |
|
|
if match: |
|
|
return match.group(1), match.group(2) |
|
|
return None, None |
|
|
|
|
|
|
|
|
def parse_sample_id_to_population_code(plain_text_content):
    """Build a sample-ID -> population-code map from the document text.

    The lower-cased text is searched for the first section introduced by
    "The sample identification of each population is as follows:" or, failing
    that, "## table"; the section ends at the next line starting with "##" or
    at the end of the text. Inside it, entries of the form ``ID POP`` or
    ``ID1-ID2 POP`` are read.

    * ``ID POP`` maps the single ID.
    * ``ID1-ID2 POP`` with the same prefix is expanded to every number in the
      inclusive range. Expanded IDs are rendered from integers, so leading
      zeros of the endpoints are not kept.
    * ``ID1-ID2 POP`` with different prefixes maps only the two endpoints.

    Args:
        plain_text_content: Plain text of the document.

    Returns:
        ``(sample_id_map, contiguous_ranges_data)``. ``sample_id_map`` maps
        upper-cased sample IDs to upper-cased population codes;
        ``contiguous_ranges_data`` is a ``defaultdict(list)`` mapping an
        upper-cased prefix to ``(start, end, population_code)`` tuples. Both
        are empty (and a warning is printed) when no section is found.
    """
    sample_id_map = {}
    contiguous_ranges_data = defaultdict(list)

    lowered_text = plain_text_content.lower()
    section = None
    for marker in ("The sample identification of each population is as follows:", "## table"):
        section = re.search(
            re.escape(marker.lower()) + r"\s*(.*?)(?=\n##|\Z)",
            lowered_text,
            re.DOTALL,
        )
        if section:
            break

    if not section:
        print("Warning: 'Sample ID Population Code' section start marker not found or block empty.")
        return sample_id_map, contiguous_ranges_data

    mapping_pattern = re.compile(
        r'\b([A-Z0-9]+\d+)(?:-([A-Z0-9]+\d+))?\s+([A-Z0-9]+)\b',
        re.IGNORECASE)

    for entry in mapping_pattern.finditer(section.group(1).strip()):
        first_id, second_id, pop_code = entry.groups()
        pop_code_upper = pop_code.upper()

        first_prefix, first_number = _parse_individual_code_parts(first_id)
        if first_prefix is None:
            continue
        first_key = f"{first_prefix.upper()}{first_number}"

        # Single ID, or a range whose second endpoint cannot be split.
        if not second_id:
            sample_id_map[first_key] = pop_code_upper
            continue
        second_prefix, second_number = _parse_individual_code_parts(second_id)
        if second_prefix is None:
            sample_id_map[first_key] = pop_code_upper
            continue
        second_key = f"{second_prefix.upper()}{second_number}"

        # Different prefixes cannot be expanded: keep the endpoints only.
        if first_prefix.lower() != second_prefix.lower():
            sample_id_map[first_key] = pop_code_upper
            sample_id_map[second_key] = pop_code_upper
            continue

        try:
            start_num = int(first_number)
            end_num = int(second_number)
            for num in range(start_num, end_num + 1):
                sample_id_map[f"{first_prefix.upper()}{num}"] = pop_code_upper
            contiguous_ranges_data[first_prefix.upper()].append(
                (start_num, end_num, pop_code_upper)
            )
        except ValueError:
            print(f" DEBUG_PARSING: ValueError in range conversion for {first_number}-{second_number}. Adding endpoints only.")
            sample_id_map[first_key] = pop_code_upper
            sample_id_map[second_key] = pop_code_upper

    return sample_id_map, contiguous_ranges_data
|
|
|
|
|
# Region keyword (lower case) -> country. get_country_from_text consults this
# table, in insertion order, after the pycountry name checks found nothing.
country_keywords_regional_overrides = {
    "north thailand": "Thailand",
    "central thailand": "Thailand",
    "northeast thailand": "Thailand",
    "east myanmar": "Myanmar",
    "west thailand": "Thailand",
    "central india": "India",
    "east india": "India",
    "northeast india": "India",
    "south sibera": "Russia",
    "siberia": "Russia",
    "yunnan": "China",
    "sumatra": "Indonesia",
    "borneo": "Indonesia",
    "northern mindanao": "Philippines",
    "west malaysia": "Malaysia",
    "mongolia": "China",
    "beijing": "China",
    "north laos": "Laos",
    "central laos": "Laos",
    "west myanmar": "Myanmar",
}
|
|
|
|
|
|
|
|
def get_country_from_text(text):
    """Resolve the country mentioned in ``text``.

    Lookup order:

    1. ``text`` equals (ignoring case) the ``name``, ``common_name`` or
       ``official_name`` of a pycountry entry: that string is returned.
       Exact matches are checked for every country before any partial match
       is considered, so "Romania" can no longer be answered with "Oman".
    2. A pycountry ``name`` (preferred) or ``common_name`` occurs inside
       ``text``. The occurrence must not be preceded by a letter ("Somalia"
       does not mention "Mali"), and a name that is merely a fragment of a
       longer name that also occurs is dropped ("Nigeria" mentions Nigeria,
       not Niger). Among the remaining names the first one in pycountry's
       iteration order is returned.
    3. A key of ``country_keywords_regional_overrides`` occurs inside
       ``text``: the mapped country is returned.
    4. Otherwise ``"unknown"``.

    Args:
        text: Free text such as a table cell or a location description.

    Returns:
        The matched country string, or ``"unknown"``.
    """
    text_lower = text.lower()

    def mentions(name_lower):
        # Cheap substring test first; the regex only adds the condition that
        # the name does not start in the middle of a word.
        if name_lower not in text_lower:
            return False
        pattern = r'(?<![^\W\d_])' + re.escape(name_lower)
        return re.search(pattern, text_lower) is not None

    partial_hits = []
    for country in pycountry.countries:
        name = country.name
        common_name = country.common_name if hasattr(country, 'common_name') else None
        official_name = country.official_name if hasattr(country, 'official_name') else None

        for label in (name, common_name, official_name):
            if label is not None and text_lower == label.lower():
                return label

        # Remember at most one partial hit per country, preferring ``name``.
        for label in (name, common_name):
            if label is not None and mentions(label.lower()):
                partial_hits.append(label)
                break

    lowered_hits = [hit.lower() for hit in partial_hits]
    for hit, hit_lower in zip(partial_hits, lowered_hits):
        shadowed = any(
            hit_lower != other and hit_lower in other for other in lowered_hits
        )
        if not shadowed:
            return hit

    for keyword, country in country_keywords_regional_overrides.items():
        if keyword in text_lower:
            return country

    # Broad regions ("north asia", "southeast asia", "east asia") cannot be
    # mapped to one country and are reported as unknown like everything else.
    return "unknown"
|
|
|
|
|
|
|
|
# English stopwords from the NLTK corpus; parse_population_code_to_country
# refuses to store values whose lower-cased form is one of these words.
non_meaningful_pop_names = {*stopwords.words('english')}
|
|
|
|
|
def parse_population_code_to_country(plain_text_content, table_strings):
    """Derive population-code lookups from tables and running text.

    Every table row (and afterwards every sentence of the plain text) is
    matched against a pattern of the shape
    ``<population name> <CODE> <specific location> <country/direction> ...``.
    The country is resolved with ``get_country_from_text`` from the country
    text first and from the specific-location text as a fallback.

    For sentence matches a value is stored only if the population code is not
    an English stopword, and each individual value is skipped when it is a
    stopword itself.

    Args:
        plain_text_content: Plain text of the document.
        table_strings: Table strings, each parsed with
            ``parse_literal_python_list``.

    Returns:
        ``(pop_code_country_map, pop_code_ethnicity_map,
        pop_code_specific_loc_map)``, all keyed by upper-cased population
        code.
    """
    pop_code_country_map = {}
    pop_code_ethnicity_map = {}
    pop_code_specific_loc_map = {}

    pop_info_pattern = re.compile(
        r'([A-Za-z\s]+?)\s+([A-Z]+\d*)\s+'
        r'([A-Za-z\s\(\)\-,\/]+?)\s+'
        r'(North+|South+|West+|East+|Thailand|Laos|Cambodia|Myanmar|Philippines|Indonesia|Malaysia|China|India|Taiwan|Vietnam|Russia|Nepal|Japan|South Korea)\b'
        r'(?:.*?([A-Za-z\s\-]+))?\s*'
        r'(\d+(?:\s+\d+\.?\d*)*)?',
        re.IGNORECASE
    )

    # Ethnicity spellings that are stored in a fixed capitalisation.
    canonical_ethnicities = {
        "khon mueang": "Khon Mueang",
        "lawa": "Lawa",
        "mon": "Mon",
        "seak": "Seak",
        "nyaw": "Nyaw",
        "nyahkur": "Nyahkur",
        "suay": "Suay",
        "soa": "Soa",
        "bru": "Bru",
        "khamu": "Khamu",
    }

    def split_match(match):
        """Return (pop_name, pop_code, specific_location, country) of a match."""
        pop_name = match.group(1).strip()
        pop_code = match.group(2).upper()
        specific_loc_text = match.group(3).strip()
        final_country = get_country_from_text(match.group(4).strip())
        if final_country == 'unknown':
            final_country = get_country_from_text(specific_loc_text)
        return pop_name, pop_code, specific_loc_text, final_country

    # --- Tables -----------------------------------------------------------
    for table_str in table_strings:
        table_data = parse_literal_python_list(table_str)
        if not table_data:
            continue

        if isinstance(table_data[0], list):
            row_texts = [" ".join(map(str, row)) for row in table_data]
        else:
            # Flat list: the whole table is treated as one row.
            row_texts = [" ".join(map(str, table_data))]

        for row_text in row_texts:
            match = pop_info_pattern.search(row_text)
            if not match:
                continue
            pop_name, pop_code, specific_loc_text, final_country = split_match(match)
            if pop_code:
                pop_code_country_map[pop_code] = final_country
                pop_code_ethnicity_map[pop_code] = pop_name
                pop_code_specific_loc_map[pop_code] = specific_loc_text

    # --- Running text -----------------------------------------------------
    # The module is imported as ``import core.data_preprocess``, which binds
    # only the name ``core``; the helper therefore has to be reached through
    # the package.
    for sentence in core.data_preprocess.extract_sentences(plain_text_content):
        for match in pop_info_pattern.finditer(sentence):
            pop_name, pop_code, specific_loc_text, final_country = split_match(match)

            if pop_code.lower() not in non_meaningful_pop_names:
                if final_country.lower() not in non_meaningful_pop_names:
                    pop_code_country_map[pop_code] = final_country
                if pop_name.lower() not in non_meaningful_pop_names:
                    pop_code_ethnicity_map[pop_code] = pop_name
                if specific_loc_text.lower() not in non_meaningful_pop_names:
                    pop_code_specific_loc_map[pop_code] = specific_loc_text

            # NOTE(review): the source this was restored from had lost its
            # indentation; the canonical-spelling override is applied per
            # match here -- confirm this is the intended nesting level.
            canonical = canonical_ethnicities.get(pop_name.lower())
            if canonical is not None:
                pop_code_ethnicity_map[pop_code] = canonical

    return pop_code_country_map, pop_code_ethnicity_map, pop_code_specific_loc_map
|
|
|
|
|
def general_parse_population_code_to_country(plain_text_content, table_strings):
    """Read population metadata from tables with recognisable headers.

    For every table string whose first parsed row is a header, the header
    names (lower-cased) select one of three layouts:

    * ``id`` + ``country`` (optionally ``province`` and
      ``population group / region``): code -> country / location / ethnicity.
    * ``sample id`` + ``population code``: sample ID -> population code.
    * ``population code`` + ``country``: code -> country.

    Data rows are taken from the first parsed row of each entry in
    ``table_strings[1:]``.
    NOTE(review): this keeps the existing convention that the table strings
    after the first one each carry one data row -- confirm against the caller
    whether ``table_data[1:]`` was intended instead.

    Table strings that cannot be parsed and rows too short for the required
    columns are skipped; cell values are converted with ``str`` before being
    stripped, so numeric cells no longer raise.

    Args:
        plain_text_content: Unused; kept for signature compatibility.
        table_strings: Table strings, each parsed with
            ``parse_literal_python_list``.

    Returns:
        ``(pop_code_country_map, pop_code_ethnicity_map,
        pop_code_specific_loc_map, sample_id_to_pop_code)``.
    """
    pop_code_country_map = {}
    pop_code_ethnicity_map = {}
    pop_code_specific_loc_map = {}
    sample_id_to_pop_code = {}

    def cell(row, index):
        return str(row[index]).strip()

    def data_rows(min_length):
        """Yield usable data rows that have at least ``min_length`` cells."""
        for row_str in table_strings[1:]:
            parsed = parse_literal_python_list(row_str)
            if not parsed or not isinstance(parsed[0], (list, tuple)):
                continue
            if len(parsed[0]) >= min_length:
                yield parsed[0]

    for table_str in table_strings:
        table_data = parse_literal_python_list(table_str)
        if not table_data or not isinstance(table_data[0], list):
            continue

        header_row = [str(col).lower() for col in table_data[0]]
        header_map = {col: idx for idx, col in enumerate(header_row)}

        if 'id' in header_map and 'country' in header_map:
            for row in data_rows(len(header_row)):
                pop_code = cell(row, header_map['id'])
                pop_code_country_map[pop_code] = cell(row, header_map['country'])
                pop_code_specific_loc_map[pop_code] = (
                    cell(row, header_map['province'])
                    if 'province' in header_map else 'unknown'
                )
                pop_code_ethnicity_map[pop_code] = (
                    cell(row, header_map['population group / region'])
                    if 'population group / region' in header_map else 'unknown'
                )

        elif 'sample id' in header_map and 'population code' in header_map:
            id_idx = header_map['sample id']
            code_idx = header_map['population code']
            # Both columns must exist in the row (this implies >= 2 cells).
            for row in data_rows(max(id_idx, code_idx) + 1):
                sample_id_to_pop_code[cell(row, id_idx).upper()] = cell(row, code_idx).upper()

        elif 'population code' in header_map and 'country' in header_map:
            code_idx = header_map['population code']
            country_idx = header_map['country']
            for row in data_rows(max(code_idx, country_idx) + 1):
                pop_code_country_map[cell(row, code_idx).upper()] = cell(row, country_idx)

    return pop_code_country_map, pop_code_ethnicity_map, pop_code_specific_loc_map, sample_id_to_pop_code
|
|
|
|
|
def chunk_text(text, chunk_size=500, overlap=50):
    """Split ``text`` into word-based chunks that overlap.

    Args:
        text: Text to split; words are separated on whitespace.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared by two consecutive chunks.

    Returns:
        List of chunks (words re-joined with single spaces); ``[]`` for text
        without words.

    Raises:
        ValueError: If more than one chunk is needed but
            ``chunk_size - overlap`` is not positive, in which case the
            window could never advance (this used to loop forever).
    """
    words = text.split()
    num_words = len(words)
    step = chunk_size - overlap

    chunks = []
    start = 0
    while start < num_words:
        end = min(start + chunk_size, num_words)
        chunks.append(" ".join(words[start:end]))
        if end == num_words:
            break
        if step <= 0:
            raise ValueError(
                f"chunk_size ({chunk_size}) must be greater than overlap ({overlap}) "
                "to split text that is longer than one chunk."
            )
        start += step
    return chunks
|
|
|
|
|
def build_vector_index_and_data(doc_path, index_path="faiss_index.bin", chunks_path="document_chunks.json", structured_path="structured_lookup.json"):
    """Build and persist the structured lookup and the FAISS index of a document.

    Steps:
      1. Read the .docx file and extract structured data, first with
         ``general_parse_population_code_to_country`` and, if that yields no
         entries, with ``parse_sample_id_to_population_code`` plus
         ``parse_population_code_to_country``. The lookup is written to
         ``structured_path`` as JSON.
      2. Normalise plain text and tables, concatenate them and split the
         result with ``chunk_text``.
      3. Embed every chunk with ``get_embedding``, add the vectors to a
         ``faiss.IndexFlatL2`` index, and write the index to ``index_path``
         and the chunks to ``chunks_path``.

    Args:
        doc_path: Path of the .docx document.
        index_path: Output path of the FAISS index.
        chunks_path: Output path of the JSON list of chunks.
        structured_path: Output path of the JSON structured lookup.

    Returns:
        ``(master_lookup, index, document_chunks, all_clean_chunk)``.

    Raises:
        ValueError: If the document produced no chunks to embed.
    """

    def structured_entries(sample_ids, countries, ethnicities, locations):
        """Build one entry per sample ID, or per population code if there are none."""
        if sample_ids:
            keyed_codes = list(sample_ids.items())
        else:
            keyed_codes = [(code, code) for code in countries]
        return {
            key: {
                'population_code': code,
                'country': countries.get(code, 'unknown'),
                'type': 'modern',
                'ethnicity': ethnicities.get(code, 'unknown'),
                'specific_location': locations.get(code, 'unknown'),
            }
            for key, code in keyed_codes
        }

    print("Step 1: Reading document and extracting structured data...")
    plain_text_content, table_strings, document_title = read_docx_text(doc_path)
    pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc, sample_id_map = general_parse_population_code_to_country(plain_text_content, table_strings)

    final_structured_entries = structured_entries(
        sample_id_map, pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc
    )
    if not final_structured_entries:
        # Header-based parsing found nothing: fall back to the pattern-based
        # parsers. The range data of the first parser is not used here.
        sample_id_map, _ = parse_sample_id_to_population_code(plain_text_content)
        pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc = parse_population_code_to_country(plain_text_content, table_strings)
        final_structured_entries = structured_entries(
            sample_id_map, pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc
        )

    master_lookup = {
        'document_title': document_title,
        'pop_code_to_country': pop_code_to_country,
        'pop_code_to_ethnicity': pop_code_to_ethnicity,
        'pop_code_to_specific_loc': pop_code_to_specific_loc,
        'sample_id_map': sample_id_map,
        'final_structured_entries': final_structured_entries
    }
    print(f"Structured lookup built with {len(final_structured_entries)} entries in 'final_structured_entries'.")

    with open(structured_path, 'w') as f:
        json.dump(master_lookup, f, indent=4)
    print(f"Structured lookup saved to {structured_path}.")

    print("Step 2: Chunking document for RAG vector index...")
    # The module is imported as ``import core.data_preprocess``, which binds
    # only the name ``core``; the helper is therefore reached via the package.
    clean_text, clean_table = "", ""
    if plain_text_content:
        clean_text = core.data_preprocess.normalize_for_overlap(plain_text_content)
    if table_strings:
        clean_table = core.data_preprocess.normalize_for_overlap(". ".join(table_strings))
    all_clean_chunk = clean_text + clean_table
    document_chunks = chunk_text(all_clean_chunk)
    print(f"Document chunked into {len(document_chunks)} chunks.")

    print("Step 3: Generating embeddings for chunks (this might take time and cost API calls)...")
    chunk_embeddings = []
    for i, chunk in enumerate(document_chunks):
        embedding = get_embedding(chunk, task_type="RETRIEVAL_DOCUMENT")
        if embedding is not None and embedding.shape[0] > 0:
            chunk_embeddings.append(embedding)
        else:
            print(f"Warning: Failed to get valid embedding for chunk {i}. Skipping.")
            # A zero vector keeps index positions aligned with document_chunks.
            chunk_embeddings.append(np.zeros(768, dtype='float32'))

    if not chunk_embeddings:
        raise ValueError("No valid embeddings generated. Check get_embedding function and API.")

    embedding_dimension = chunk_embeddings[0].shape[0]
    index = faiss.IndexFlatL2(embedding_dimension)
    index.add(np.array(chunk_embeddings))

    faiss.write_index(index, index_path)
    with open(chunks_path, "w") as f:
        json.dump(document_chunks, f)

    print(f"FAISS index built and saved to {index_path}.")
    print(f"Document chunks saved to {chunks_path}.")
    return master_lookup, index, document_chunks, all_clean_chunk
|
|
|
|
|
|
|
|
def load_rag_assets(index_path="faiss_index.bin", chunks_path="document_chunks.json", structured_path="structured_lookup.json"):
    """Load the previously saved structured lookup, FAISS index and chunks.

    Args:
        index_path: Path of the FAISS index file.
        chunks_path: Path of the JSON file holding the chunks.
        structured_path: Path of the JSON structured lookup.

    Returns:
        ``(master_structured_lookup, index, chunks)``. The lookup is ``{}``
        when its file does not exist. ``index`` is ``None`` and ``chunks`` is
        ``[]`` when either of their files is missing or loading them raises
        (the error is printed).
    """
    print("Loading RAG assets...")

    master_structured_lookup = {}
    if not os.path.exists(structured_path):
        print("Structured lookup file not found. Rebuilding is likely needed.")
    else:
        with open(structured_path, 'r') as f:
            master_structured_lookup = json.load(f)
        print("Structured lookup loaded.")

    if not (os.path.exists(index_path) and os.path.exists(chunks_path)):
        print("FAISS index or chunks files not found.")
        return master_structured_lookup, None, []

    # Index and chunks are only handed out together: a failure on either
    # discards both.
    try:
        loaded_index = faiss.read_index(index_path)
        with open(chunks_path, "r") as f:
            loaded_chunks = json.load(f)
        print("FAISS index and chunks loaded.")
    except Exception as e:
        print(f"Error loading FAISS index or chunks: {e}. Will rebuild.")
        loaded_index, loaded_chunks = None, []

    return master_structured_lookup, loaded_index, loaded_chunks
|
|
|
|
|
def exactInContext(text, keyword):
    """Return True when ``keyword`` occurs in ``text``, ignoring case."""
    haystack = text.lower()
    return keyword.lower() in haystack
|
|
def chooseContextLLM(contexts, kw):
    """Pick the context that should be handed to the LLM for keyword ``kw``.

    The first non-empty context (in the mapping's own order) that contains
    ``kw`` case-insensitively wins. If none does, the first non-empty one of
    ``"all_output"``, ``"chunk"`` and ``"document_chunk"`` is used.

    Args:
        contexts: Mapping of context name to context text; it must contain
            the three keys named above whenever the fallback is reached.
        kw: Keyword to look for.

    Returns:
        ``(name, text)`` of the chosen context, or ``(None, None)`` when all
        candidates are empty.
    """
    for label, candidate in contexts.items():
        if candidate and kw.lower() in candidate.lower():
            return label, candidate

    for label in ("all_output", "chunk", "document_chunk"):
        fallback = contexts[label]
        if fallback:
            return label, fallback
    return None, None
|
|
def clean_llm_output(llm_response_text, output_format_str):
    """Parse a comma-separated LLM answer into its metadata fields.

    The expected layout depends on ``output_format_str``:

    * ``"ethnicity, specific_location/unknown"``: two fields, matched against
      the whole response.
    * ``"modern/ancient/unknown, ethnicity, specific_location/unknown"``:
      three fields matched against the whole response, falling back to two
      fields, one field, and finally a ``Type: ...`` search.
    * anything else: ``country, type, ethnicity, specific_location`` matched
      per line, falling back to two fields (country, type) and finally to
      ``Country: ...`` / ``Type: ...`` searches.

    One record is produced per line of the stripped response; fields that
    cannot be extracted are ``"unknown"``. The distinct values of every field
    are joined with ``" or "`` in order of first appearance.

    Args:
        llm_response_text: Raw text returned by the LLM.
        output_format_str: Format string selecting the layout.

    Returns:
        A tuple of joined strings: ``(ethnicity, specific_location)``,
        ``(type, ethnicity, specific_location)`` or
        ``(country, type, ethnicity, specific_location)``.
    """
    ethnicity_location_format = "ethnicity, specific_location/unknown"
    type_ethnicity_location_format = "modern/ancient/unknown, ethnicity, specific_location/unknown"

    def stripped_groups(found):
        return [group.strip() for group in found.groups()]

    records = []
    for raw_line in llm_response_text.strip().split('\n'):
        line = raw_line.strip()
        country = sample_type = ethnicity = location = "unknown"

        if output_format_str == ethnicity_location_format:
            found = re.search(r'^\s*([^,]+?),\s*(.+?)\s*$', llm_response_text)
            if found:
                ethnicity, location = stripped_groups(found)
            else:
                print(" DEBUG: LLM did not follow expected 2-field format for targeted RAG. Defaulting to unknown for ethnicity/specific_location.")

        elif output_format_str == type_ethnicity_location_format:
            three_fields = re.search(r'^\s*([^,]+?),\s*([^,]+?),\s*(.+?)\s*$', llm_response_text)
            if three_fields:
                sample_type, ethnicity, location = stripped_groups(three_fields)
            else:
                two_fields = re.search(r'^\s*([^,]+?),\s*([^,]+?)\s*$', llm_response_text)
                if two_fields:
                    sample_type, ethnicity = stripped_groups(two_fields)
                else:
                    one_field = re.search(r'^\s*([^,]+?)\s*$', llm_response_text)
                    if one_field:
                        sample_type = one_field.group(1).strip()
                    else:
                        print(" DEBUG: LLM did not follow any expected simplified format. Attempting verbose parsing fallback.")
                        verbose_type = re.search(r'Type:\s*([A-Za-z\s-]+)', llm_response_text)
                        if verbose_type:
                            sample_type = verbose_type.group(1).strip()

        else:
            four_fields = re.search(r'^\s*([^,]+?),\s*([^,]+?),\s*([^,]+?),\s*(.+?)\s*$', line)
            if four_fields:
                country, sample_type, ethnicity, location = stripped_groups(four_fields)
            else:
                print(f" DEBUG: Line did not follow expected 4-field format: {line}")
                two_fields = re.search(r'^\s*([^,]+?),\s*([^,]+?)\s*$', line)
                if two_fields:
                    country, sample_type = stripped_groups(two_fields)
                else:
                    print(f" DEBUG: Fallback to verbose-style parsing: {line}")
                    verbose_country = re.search(r'Country:\s*([A-Za-z\s-]+)', line)
                    verbose_type = re.search(r'Type:\s*([A-Za-z\s-]+)', line)
                    if verbose_country:
                        country = verbose_country.group(1).strip()
                    if verbose_type:
                        sample_type = verbose_type.group(1).strip()

        records.append({
            "country": country,
            "type": sample_type,
            "ethnicity": ethnicity,
            "specific_location": location
        })

    if output_format_str == ethnicity_location_format:
        wanted_fields = ("ethnicity", "specific_location")
    elif output_format_str == type_ethnicity_location_format:
        wanted_fields = ("type", "ethnicity", "specific_location")
    else:
        wanted_fields = ("country", "type", "ethnicity", "specific_location")

    # dict.fromkeys keeps the first occurrence of every value, in order.
    return tuple(
        " or ".join(dict.fromkeys(record[field] for record in records))
        for field in wanted_fields
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_multi_sample_llm_output(raw_response: str, output_format_str):
    """Pair the answers on the first response line with their explanations.

    The first line of ``raw_response`` holds the answers separated by
    ``", "``; an answer written as ``"label: value"`` is reduced to the text
    after the first ``": "``. The remaining non-blank lines are explanations;
    if there is exactly one such line and it contains several ``". "``
    separated sentences, those sentences are used instead.

    Each field named in ``output_format_str`` (also separated by ``", "``)
    takes the answer at the same position. An answer that is missing or
    mentions "unknown" becomes ``"unknown"`` with the explanation
    ``"unknown"``. Every other answer consumes the next explanation; once the
    explanations run out, the last consumed one is reused (``""`` if there
    was none).

    Args:
        raw_response: Raw LLM response text.
        output_format_str: Comma-separated field names; ``None`` or ``""``
            yields an empty result.

    Returns:
        ``{field: {"answer": ..., field + "_explanation": ...}}``.
    """
    metadata_list = {}

    response_lines = raw_response.split("\n")
    output_answers = response_lines[0].split(", ")
    explanation_lines = [line for line in response_lines[1:] if line.strip()]
    print("raw explanation line which split by new line: ", explanation_lines)
    if len(explanation_lines) == 1:
        sentences = explanation_lines[0].split(". ")
        if len(sentences) > 1:
            explanation_lines = [s for s in sentences if s.strip()]
            print("explain line split by dot: ", explanation_lines)

    # Checked before the format string is split so that None is accepted too.
    if not output_format_str:
        return metadata_list

    explain = ""
    for position, output in enumerate(output_format_str.split(", ")):
        answer = "unknown"
        if position < len(output_answers):
            answer = output_answers[position]
            if ": " in answer:
                answer = answer.split(": ")[1]
            if "unknown" in answer.lower():
                answer = "unknown"

        if answer != "unknown":
            if explanation_lines:
                explain = explanation_lines.pop(0)
            explanation = explain
        else:
            explanation = "unknown"

        metadata_list[output] = {"answer": answer, output + "_explanation": explanation}
    return metadata_list
|
|
|
|
|
def merge_metadata_outputs(metadata_list):
    """Merge several metadata dicts into one.

    Only the keys of the first dict are merged. For each key the distinct
    values (in order of first appearance) are collected from every dict that
    has the key; ``"unknown"`` is dropped whenever a real value exists. A
    single remaining value is returned as is, several are joined with
    ``" or "``, and a key whose values are all ``"unknown"`` stays
    ``"unknown"`` (it used to collapse to an empty string).

    Args:
        metadata_list: List of dicts, expected to share the same keys.

    Returns:
        The merged dict; ``{}`` for an empty input.
    """
    if not metadata_list:
        return {}

    merged = {}
    for key in metadata_list[0]:
        values = [md[key] for md in metadata_list if key in md]
        known_values = [v for v in dict.fromkeys(values) if v != "unknown"]
        if not known_values:
            merged[key] = "unknown"
        elif len(known_values) == 1:
            merged[key] = known_values[0]
        else:
            merged[key] = " or ".join(known_values)

    return merged
|
|
|
|
|
|
|
|
def _iter_matching_ranges(code, master_structured_lookup):
    """Yield every contiguous range of the structured lookup that contains *code*.

    The code is split with ``_parse_individual_code_parts`` into a prefix and a
    numeric string. Nothing is yielded when the code cannot be split, the
    numeric part is not an integer, or the prefix has no registered ranges.

    Yields:
        ``(prefix_upper, start_num, end_num, pop_code, country)`` tuples, where
        ``country`` is looked up in ``pop_code_to_country`` and is ``'unknown'``
        when the population code has no mapped country.
    """
    prefix, num_str = _parse_individual_code_parts(code)
    if prefix is None or num_str is None:
        return
    try:
        number = int(num_str)
    except ValueError:
        return

    prefix_upper = prefix.upper()
    contiguous_ranges = master_structured_lookup.get('contiguous_ranges', {})
    pop_code_to_country = master_structured_lookup.get('pop_code_to_country', {})
    if prefix_upper not in contiguous_ranges:
        return

    for start_num, end_num, pop_code in contiguous_ranges[prefix_upper]:
        if start_num <= number <= end_num:
            yield (prefix_upper, start_num, end_num, pop_code,
                   pop_code_to_country.get(pop_code, 'unknown'))


def query_document_info(query_word, alternative_query_word, metadata, master_structured_lookup, faiss_index, document_chunks, llm_api_function, chunk=None, all_output=None, model_ai=None):
    """Infer the country and sample type (modern/ancient) for a sample identifier.

    Queries the document using a hybrid approach:
      1. Local structured lookup: direct match on ``query_word``, then a
         numeric-range heuristic, then the same two steps on the alternative
         identifier.
      2. RAG: embed the query, retrieve the 4 nearest chunks from
         ``faiss_index``, pick a context with ``chooseContextLLM`` and ask the
         LLM.

    The RAG step always runs and its answer overwrites the lookup results; the
    lookup results are only returned on the early "embedding_failed" exits.

    Args:
        query_word: Primary sample identifier.
        alternative_query_word: Optional secondary identifier; everything after
            the first ``'.'`` is dropped before use.
        metadata: Dict with the keys ``country``, ``specific_location``,
            ``ethnicity``, ``sample_type``, ``collection_date``, ``isolate``,
            ``title`` and ``all_features``; may be falsy, in which case every
            field defaults to ``"unknown"``.
        master_structured_lookup: Dict read for ``final_structured_entries``,
            ``contiguous_ranges`` and ``pop_code_to_country``.
        faiss_index: Index object exposing ``search(vectors, k)``.
        document_chunks: Sequence of text chunks addressed by the ids returned
            from ``faiss_index.search``.
        llm_api_function: Accepted for interface compatibility; this function
            does not use it and calls ``call_llm_api`` directly.
        chunk: Optional context forwarded to ``chooseContextLLM``.
        all_output: Optional context forwarded to ``chooseContextLLM``.
        model_ai: Optional model name. ``"gemini-1.5-flash-latest"`` selects
            the ``GOOGLE_API_KEY`` credentials and that model's prices; any
            other value (including ``None``) selects ``GOOGLE_API_KEY_BACKUP``
            and the gemini-2.5-flash-lite prices.

    Returns:
        On success the 6-tuple ``(country, sample_type, "rag_llm",
        country_explanation, sample_type_explanation, total_query_cost)``.
        When the embedding step fails the tuple is ``(country, sample_type,
        "embedding_failed", ethnicity, specific_location, 0)``.
    """
    # Pricing and token-counting model are resolved up front on every path, so
    # the cost code below can never hit an unassigned local.
    if model_ai == "gemini-1.5-flash-latest":
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        price_in, price_out, price_embed = 0.000075, 0.0003, 0.000025
        token_counter = genai.GenerativeModel("gemini-1.5-flash-latest")
    else:
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY_BACKUP"))
        price_in, price_out, price_embed = 0.00010, 0.00040, 0.00015
        token_counter = genai.GenerativeModel("gemini-2.5-flash-lite")

    if metadata:
        extracted_country = metadata["country"]
        extracted_specific_location = metadata["specific_location"]
        extracted_ethnicity = metadata["ethnicity"]
        extracted_type = metadata["sample_type"]
        extracted_col_date = metadata["collection_date"]
        extracted_iso = metadata["isolate"]
        extracted_title = metadata["title"]
        extracted_features = metadata["all_features"]
    else:
        extracted_country = extracted_specific_location = "unknown"
        extracted_ethnicity = extracted_type = "unknown"
        extracted_col_date = extracted_iso = extracted_title = "unknown"
        extracted_features = "unknown"

    if alternative_query_word:
        alternative_query_word_cleaned = alternative_query_word.split('.')[0]
    else:
        alternative_query_word_cleaned = alternative_query_word
    has_alternative = bool(alternative_query_word_cleaned) and alternative_query_word_cleaned != query_word

    country_explanation, sample_type_explanation = None, None
    final_structured_entries = master_structured_lookup.get('final_structured_entries', {})

    method_used = 'unknown'
    # Recorded for parity with the lookup entries; not part of the return value.
    population_code_from_sl = 'unknown'
    total_query_cost = 0

    # --- Attempt 1: direct structured lookup on the primary identifier. ---
    try:
        print("try attempt 1 in model query")
        structured_info = final_structured_entries.get(query_word.upper())
        if structured_info:
            if extracted_country == 'unknown':
                extracted_country = structured_info['country']
            if extracted_type == 'unknown':
                extracted_type = structured_info['type']
            population_code_from_sl = structured_info['population_code']
            method_used = "structured_lookup_direct"
            print(f"'{query_word}' found in structured lookup (direct match).")
    except Exception as e:  # best effort: fall through to the next attempt
        print(f"pass attempt 1 in model query: {e}")

    # --- Attempt 2: numeric-range heuristic on the primary identifier. ---
    try:
        print("try attempt 2 in model query")
        if method_used == 'unknown':
            for prefix, start_num, end_num, pop_code, country in _iter_matching_ranges(query_word, master_structured_lookup):
                if country == 'unknown':
                    print(f"'{query_word}' heuristic match found, but country unknown. Will fall to RAG below.")
                    continue
                if extracted_country == 'unknown':
                    extracted_country = country
                if extracted_type == 'unknown':
                    extracted_type = 'modern'
                population_code_from_sl = pop_code
                method_used = "structured_lookup_heuristic_range_match"
                print(f"'{query_word}' not direct. Heuristic: Falls within range {prefix}{start_num}-{prefix}{end_num}.")
                break
    except Exception as e:
        print(f"pass attempt 2 in model query: {e}")

    # --- Attempt 3: same two steps on the alternative identifier. ---
    try:
        print("try attempt 3 in model query")
        if method_used == 'unknown' and has_alternative:
            print(f"'{query_word}' not found in structured (or heuristic). Trying alternative '{alternative_query_word_cleaned}'.")
            structured_info_alt = final_structured_entries.get(alternative_query_word_cleaned.upper())
            if structured_info_alt:
                if extracted_country == 'unknown':
                    extracted_country = structured_info_alt['country']
                if extracted_type == 'unknown':
                    extracted_type = structured_info_alt['type']
                population_code_from_sl = structured_info_alt['population_code']
                method_used = "structured_lookup_alt_direct"
                print(f"Alternative '{alternative_query_word_cleaned}' found in structured lookup (direct match).")
            else:
                for _prefix, _start, _end, pop_code, country in _iter_matching_ranges(alternative_query_word_cleaned, master_structured_lookup):
                    if country == 'unknown':
                        print(f"Alternative '{alternative_query_word_cleaned}' heuristic match found, but country unknown. Will fall to RAG below.")
                        continue
                    if extracted_country == 'unknown':
                        extracted_country = country
                    if extracted_type == 'unknown':
                        extracted_type = 'modern'
                    population_code_from_sl = pop_code
                    method_used = "structured_lookup_alt_heuristic_range_match"
                    break
    except Exception as e:
        print(f"pass attempt 3 in model query: {e}")

    # --- RAG: always executed, regardless of the lookup outcome. ---
    print("try run rag")
    rag_query_phrase = f"'{query_word}'"
    if has_alternative:
        rag_query_phrase += f" or its alternative word '{alternative_query_word_cleaned}'"
    semantic_query_for_embedding = rag_query_phrase

    prompt_instruction_prefix = ""
    output_format_str = "country_name, modern/ancient/unknown"
    method_used = "rag_llm"
    print(f"Proceeding to FULL RAG for {rag_query_phrase}.")

    current_embedding_cost = 0
    try:
        query_embedding_vector = get_embedding(semantic_query_for_embedding, task_type="RETRIEVAL_QUERY")
        query_embedding_tokens = token_counter.count_tokens(semantic_query_for_embedding).total_tokens
        current_embedding_cost += (query_embedding_tokens / 1000) * price_embed
        print(f" DEBUG: Query embedding tokens (for '{semantic_query_for_embedding}'): {query_embedding_tokens}, cost: ${current_embedding_cost:.6f}")

        if has_alternative:
            # NOTE(review): the returned vector is not used for retrieval; the
            # call is kept so the billed request and its cost stay as before.
            get_embedding(alternative_query_word_cleaned, task_type="RETRIEVAL_QUERY")
            alt_embedding_tokens = token_counter.count_tokens(alternative_query_word_cleaned).total_tokens
            current_embedding_cost += (alt_embedding_tokens / 1000) * price_embed
            print(f" DEBUG: Alternative query ('{alternative_query_word_cleaned}') embedding tokens: {alt_embedding_tokens}, cost: ${current_embedding_cost:.6f}")
    except Exception as e:
        print(f"Error getting query embedding for RAG: {e}")
        # NOTE(review): slots 4 and 5 carry ethnicity / specific location here,
        # not the explanations returned on success. Kept for existing callers;
        # confirm this is intended.
        return extracted_country, extracted_type, "embedding_failed", extracted_ethnicity, extracted_specific_location, total_query_cost

    if query_embedding_vector is None or query_embedding_vector.shape[0] == 0:
        return extracted_country, extracted_type, "embedding_failed", extracted_ethnicity, extracted_specific_location, total_query_cost

    _, neighbour_ids = faiss_index.search(np.array([query_embedding_vector]), 4)
    retrieved_chunks_text = [
        document_chunks[idx]
        for idx in neighbour_ids[0]
        if 0 <= idx < len(document_chunks)  # skip ids outside the chunk list
    ]

    listOfcontexts = {
        "chunk": chunk,
        "all_output": all_output,
        "document_chunk": "\n".join(retrieved_chunks_text),
    }
    _, context_for_llm = chooseContextLLM(listOfcontexts, query_word)
    if not context_for_llm:
        _, context_for_llm = chooseContextLLM(listOfcontexts, alternative_query_word_cleaned)
    if not context_for_llm:
        # Last resort: build a context from the record's own metadata.
        context_for_llm = (
            f"Collection_date: {extracted_col_date}. Isolate: {extracted_iso}. "
            f"Title: {extracted_title}. Features: {extracted_features}"
        )

    # Keep the prompt below roughly one million characters.
    if len(context_for_llm) > 1000*1000:
        context_for_llm = context_for_llm[:900000]

    # Name the organism in the prompt when the feature text declares one.
    organism = "general"
    features = extracted_features
    if isinstance(features, str) and features != "unknown" and "organism" in features:
        try:
            organism = features.split("organism: ")[1].split("\n")[0]
        except IndexError:
            # "organism" is present but not in the "organism: value" form.
            organism = features.replace("\n", "; ")
    explain_list = "country or sample type (modern/ancient)"

    prompt_for_llm = (
        f"{prompt_instruction_prefix}"
        f"Given the following text snippets, analyze the entity/concept {rag_query_phrase} "
        f"or the mitochondrial DNA sample in {organism} if these identifiers are not explicitly found. "
        f"Identify its **primary associated geographic location**, preferring the most specific available: "
        f"first try to determine the exact country; if no country is explicitly mentioned, then provide "
        f"the next most specific region, continent, island, or other clear geographic area mentioned. "
        f"If no geographic clues at all are present, state 'unknown' for location. "
        f"Also, determine if the genetic sample is from a 'modern' (present-day living individual) "
        f"or 'ancient' (prehistoric/archaeological) source. "
        f"If the text does not specify ancient or archaeological context, assume 'modern'. "
        f"Provide only {output_format_str}. "
        f"If any information is not explicitly present, use the fallback rules above before defaulting to 'unknown'. "
        f"For each non-'unknown' field in {explain_list}, write one sentence explaining how it was inferred from the text "
        f"(one sentence for each). "
        f"Format your answer so that:\n"
        f"1. The **first line** contains only the {output_format_str} answer.\n"
        f"2. The **second line onward** contains the explanations.\n"
        f"\nText Snippets:\n{context_for_llm}\n\n"
        f"Output Format Example:\nChina, modern, Daur, Heilongjiang province.\n"
        f"The text explicitly states \"chinese Daur ethnic group in Heilongjiang province\", indicating the country, "
        f"the ethnicity, and the specific province. The study is published in a journal, implying research on living "
        f"individuals, hence modern."
    )

    if model_ai:
        print("back up to ", model_ai)
        # Passed positionally: the helper's second parameter is the model name.
        llm_response_text, model_instance = call_llm_api(prompt_for_llm, model_ai)
    else:
        print("still 2.5 flash gemini")
        llm_response_text, model_instance = call_llm_api(prompt_for_llm)

    print("\n--- DEBUG INFO FOR RAG ---")
    print("Retrieved Context Sent to LLM (first 500 chars):")
    print(context_for_llm[:500] + "..." if len(context_for_llm) > 500 else context_for_llm)
    print("\nRaw LLM Response:")
    print(llm_response_text)
    print("--- END DEBUG INFO ---")

    llm_cost = 0
    if model_instance:
        try:
            input_llm_tokens = token_counter.count_tokens(prompt_for_llm).total_tokens
            output_llm_tokens = token_counter.count_tokens(llm_response_text).total_tokens
            print(f" DEBUG: LLM Input tokens: {input_llm_tokens}")
            print(f" DEBUG: LLM Output tokens: {output_llm_tokens}")
            llm_cost = (input_llm_tokens / 1000) * price_in + \
                       (output_llm_tokens / 1000) * price_out
            print(f" DEBUG: Estimated LLM cost: ${llm_cost:.6f}")
        except Exception as e:
            # Cost reporting is informational; never fail the query over it.
            print(f" DEBUG: Error counting LLM tokens: {e}")
            llm_cost = 0

    total_query_cost += current_embedding_cost + llm_cost
    print(f" DEBUG: Total estimated cost for this RAG query: ${total_query_cost:.6f}")

    metadata_list = parse_multi_sample_llm_output(llm_response_text, output_format_str)

    # The parsed result is keyed by the field names in output_format_str.
    country_key, type_key = output_format_str.split(", ")
    extracted_country = metadata_list[country_key]["answer"]
    extracted_type = metadata_list[type_key]["answer"]
    country_explanation = metadata_list[country_key][country_key + "_explanation"]
    sample_type_explanation = metadata_list[type_key][type_key + "_explanation"]

    return extracted_country, extracted_type, method_used, country_explanation, sample_type_explanation, total_query_cost
|
|
|