VyLala commited on
Commit
6bcf9de
·
verified ·
1 Parent(s): ed9311e

Update mtdna_classifier.py

Browse files
Files changed (1) hide show
  1. mtdna_classifier.py +768 -763
mtdna_classifier.py CHANGED
@@ -1,764 +1,769 @@
1
- # mtDNA Location Classifier MVP (Google Colab)
2
- # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
- import os
4
- #import streamlit as st
5
- import subprocess
6
- import re
7
- from Bio import Entrez
8
- import fitz
9
- import spacy
10
- from spacy.cli import download
11
- from NER.PDF import pdf
12
- from NER.WordDoc import wordDoc
13
- from NER.html import extractHTML
14
- from NER.word2Vec import word2vec
15
- from transformers import pipeline
16
- import urllib.parse, requests
17
- from pathlib import Path
18
- from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
-
20
- # Set your email (required by NCBI Entrez)
21
- #Entrez.email = "your-email@example.com"
22
- import nltk
23
-
24
- nltk.download("stopwords")
25
- nltk.download("punkt")
26
- nltk.download('punkt_tab')
27
- # Step 1: Get PubMed ID from Accession using EDirect
28
- from Bio import Entrez, Medline
29
- import re
30
-
31
- Entrez.email = "your_email@example.com"
32
-
33
- # --- Helper Functions (Re-organized and Upgraded) ---
34
-
35
def fetch_ncbi_metadata(accession_number):
    """
    Fetch metadata directly from NCBI GenBank using Entrez.

    Location extraction priority: the geo_loc_name qualifier, then an
    "origin_locality:" note, then locality / collection_location /
    isolation_source / environmental_sample qualifiers; the country
    qualifier is a last resort for the country alone.  Ethnicity and
    sample_type (ancient/modern) are also inferred from the
    source-feature qualifiers.

    Args:
        accession_number (str): The NCBI accession number (e.g. "ON792208").

    Returns:
        dict: keys 'country', 'specific_location', 'ethnicity',
            'sample_type', 'collection_date', 'isolate', 'title', 'doi',
            'pubmed_id', 'all_features'.  String fields default to
            "unknown"; 'pubmed_id' defaults to None.
    """
    Entrez.email = "your.email@example.com"  # Required by NCBI, REPLACE WITH YOUR EMAIL

    # Single default payload, also returned on any failure (was duplicated inline).
    default_result = {"country": "unknown",
                      "specific_location": "unknown",
                      "ethnicity": "unknown",
                      "sample_type": "unknown",
                      "collection_date": "unknown",
                      "isolate": "unknown",
                      "title": "unknown",
                      "doi": "unknown",
                      "pubmed_id": None,
                      "all_features": "unknown"}

    country = "unknown"
    specific_location = "unknown"
    ethnicity = "unknown"
    sample_type = "unknown"
    collection_date = "unknown"
    isolate = "unknown"
    title = "unknown"
    doi = "unknown"
    pubmed_id = None

    KNOWN_COUNTRIES = [
        "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
        "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
        "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
        "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
        "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
        "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
        "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
        "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
        "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
        "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
        "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
        "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
        "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
        "Yemen", "Zambia", "Zimbabwe"
    ]
    COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)

    def split_location(raw_value):
        # Split a free-text location qualifier into (country, specific_location).
        # Factored out of four copy-pasted branches in the original.
        match = COUNTRY_PATTERN.search(raw_value)
        if not match:
            return "unknown", raw_value
        found = match.group(1)
        rest = re.sub(re.escape(found), '', raw_value, flags=re.IGNORECASE).strip()
        rest = rest.replace(',', '').replace(':', '').replace(';', '').strip()
        return found, rest if rest else "unknown"

    try:
        handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
        record = Entrez.read(handle)
        handle.close()

        gb_seq = None
        # Validate record structure: a non-empty list whose first item is a dict.
        if isinstance(record, list) and len(record) > 0:
            if isinstance(record[0], dict):
                gb_seq = record[0]
            else:
                print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
        else:
            print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")

        if gb_seq is None:
            return dict(default_result)

        collection_date = gb_seq.get("GBSeq_create-date", "unknown")

        # References supply the PubMed id / title / DOI when present.
        for ref in gb_seq.get("GBSeq_references", []):
            if not pubmed_id:
                pubmed_id = ref.get("GBReference_pubmed", None)
            if title == "unknown":
                title = ref.get("GBReference_title", "unknown")
            for xref in ref.get("GBReference_xref", []):
                if xref.get("GBXref_dbname") == "doi":
                    doi = xref.get("GBXref_id")
                    break

        features = gb_seq.get("GBSeq_feature-table", [])

        context_for_flagging = ""  # accumulated text for ancient/modern detection
        features_context = ""
        for feature in features:
            if feature.get("GBFeature_key") == "source":
                feature_context = ""
                qualifiers = feature.get("GBFeature_quals", [])
                found_country = "unknown"
                found_specific_location = "unknown"
                found_ethnicity = "unknown"

                temp_geo_loc_name = "unknown"
                temp_note_origin_locality = "unknown"
                temp_country_qual = "unknown"
                temp_locality_qual = "unknown"
                temp_collection_location_qual = "unknown"
                temp_isolation_source_qual = "unknown"
                temp_env_sample_qual = "unknown"
                temp_pop_qual = "unknown"
                temp_organism_qual = "unknown"
                temp_specimen_qual = "unknown"
                temp_strain_qual = "unknown"

                for qual in qualifiers:
                    qual_name = qual.get("GBQualifier_name")
                    # Flag-style qualifiers (e.g. /environmental_sample) carry no
                    # value; fall back to "" so the concatenation cannot TypeError.
                    qual_value = qual.get("GBQualifier_value") or ""
                    feature_context += qual_name + ": " + qual_value + "\n"
                    if qual_name == "collection_date":
                        collection_date = qual_value
                    elif qual_name == "isolate":
                        isolate = qual_value
                    elif qual_name == "population":
                        temp_pop_qual = qual_value
                    elif qual_name == "organism":
                        temp_organism_qual = qual_value
                    elif qual_name == "specimen_voucher" or qual_name == "specimen":
                        temp_specimen_qual = qual_value
                    elif qual_name == "strain":
                        temp_strain_qual = qual_value
                    elif qual_name == "isolation_source":
                        temp_isolation_source_qual = qual_value
                    elif qual_name == "environmental_sample":
                        temp_env_sample_qual = qual_value

                    if qual_name == "geo_loc_name":
                        temp_geo_loc_name = qual_value
                    elif qual_name == "note":
                        if qual_value.startswith("origin_locality:"):
                            temp_note_origin_locality = qual_value
                        context_for_flagging += qual_value + " "  # capture all notes for flagging
                    elif qual_name == "country":
                        temp_country_qual = qual_value
                    elif qual_name == "locality":
                        temp_locality_qual = qual_value
                    elif qual_name == "collection_location":
                        temp_collection_location_qual = qual_value

                # --- Aggregate all relevant info into context_for_flagging ---
                context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
                context_for_flagging = context_for_flagging.strip()

                # --- Determine final country and specific_location by priority ---
                if temp_geo_loc_name != "unknown":
                    parts = [p.strip() for p in temp_geo_loc_name.split(':')]
                    if len(parts) > 1:
                        found_specific_location = parts[-1]
                        found_country = parts[0]
                    else:
                        found_country = temp_geo_loc_name
                        found_specific_location = "unknown"
                elif temp_note_origin_locality != "unknown":
                    match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
                    if match:
                        location_string = match.group(1).strip()
                        parts = [p.strip() for p in location_string.split(':')]
                        if len(parts) > 1:
                            # NOTE(review): country/specific order here is inverted
                            # versus the geo_loc_name branch; preserved from the
                            # original logic — confirm intent.
                            found_country = parts[-1]
                            found_specific_location = parts[0]
                        else:
                            found_country = location_string
                            found_specific_location = "unknown"
                elif temp_locality_qual != "unknown":
                    found_country, found_specific_location = split_location(temp_locality_qual)
                elif temp_collection_location_qual != "unknown":
                    found_country, found_specific_location = split_location(temp_collection_location_qual)
                elif temp_isolation_source_qual != "unknown":
                    found_country, found_specific_location = split_location(temp_isolation_source_qual)
                elif temp_env_sample_qual != "unknown":
                    found_country, found_specific_location = split_location(temp_env_sample_qual)
                if found_country == "unknown" and temp_country_qual != "unknown":
                    country_match = COUNTRY_PATTERN.search(temp_country_qual)
                    if country_match:
                        found_country = country_match.group(1)

                country = found_country
                specific_location = found_specific_location

                # --- Determine final ethnicity ---
                if temp_pop_qual != "unknown":
                    found_ethnicity = temp_pop_qual
                elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
                    # A purely alphabetic isolate that is not a country name is
                    # assumed to be a population/ethnicity label.
                    found_ethnicity = isolate
                elif context_for_flagging != "unknown":  # use the broader context for ethnicity patterns
                    eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
                    if eth_match:
                        found_ethnicity = eth_match.group(1).strip()

                ethnicity = found_ethnicity

                # --- Determine sample_type (ancient/modern) ---
                if context_for_flagging:
                    sample_type, _ = detect_ancient_flag(context_for_flagging)
                features_context += feature_context + "\n"
                break  # only the first "source" feature is used

        # A specific location that merely repeats the country adds nothing.
        if specific_location != "unknown" and specific_location.lower() == country.lower():
            specific_location = "unknown"
        if not features_context:
            features_context = "unknown"
        return {"country": country.lower(),
                "specific_location": specific_location.lower(),
                "ethnicity": ethnicity.lower(),
                "sample_type": sample_type.lower(),
                "collection_date": collection_date,
                "isolate": isolate,
                "title": title,
                "doi": doi,
                "pubmed_id": pubmed_id,
                "all_features": features_context}

    except Exception as e:
        # Was a bare `except:` that also discarded the error details.
        print(f"Error fetching NCBI data for {accession_number}: {e}")
        return dict(default_result)
261
-
262
- # --- Helper function for country matching (re-defined from main code to be self-contained) ---
263
# Keyword → country lookup used as a fallback for free-text fields.
_country_keywords = {
    "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
    "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
    "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
    "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
    "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
    "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
    "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
    "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
    "central india": "India", "east india": "India", "northeast india": "India",
    "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
    "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
}


def get_country_from_text(text):
    """Return the country for the first keyword found in *text*
    (case-insensitive substring match), or "unknown" when none match."""
    lowered = text.lower()
    return next(
        (country for keyword, country in _country_keywords.items() if keyword in lowered),
        "unknown",
    )
283
- # The result will be seen as manualLink for the function get_paper_text
284
- # def search_google_custom(query, max_results=3):
285
- # # query should be the title from ncbi or paper/source title
286
- # GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
287
- # GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
288
- # endpoint = os.environ["SEARCH_ENDPOINT"]
289
- # params = {
290
- # "key": GOOGLE_CSE_API_KEY,
291
- # "cx": GOOGLE_CSE_CX,
292
- # "q": query,
293
- # "num": max_results
294
- # }
295
- # try:
296
- # response = requests.get(endpoint, params=params)
297
- # if response.status_code == 429:
298
- # print("Rate limit hit. Try again later.")
299
- # return []
300
- # response.raise_for_status()
301
- # data = response.json().get("items", [])
302
- # return [item.get("link") for item in data if item.get("link")]
303
- # except Exception as e:
304
- # print("Google CSE error:", e)
305
- # return []
306
-
307
def search_google_custom(query, max_results=3):
    """Search Google Custom Search for *query* (typically an NCBI record
    or paper title) and return up to *max_results* result links.

    On HTTP 429 (rate limit) the backup CSE credentials are tried via
    search_google_custom_backup.  Requires GOOGLE_CSE_API_KEY,
    GOOGLE_CSE_CX and SEARCH_ENDPOINT environment variables.

    Returns:
        list[str]: result links, or [] on any error.
    """
    GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
    GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
    endpoint = os.environ["SEARCH_ENDPOINT"]
    params = {
        "key": GOOGLE_CSE_API_KEY,
        "cx": GOOGLE_CSE_CX,
        "q": query,
        "num": max_results
    }
    try:
        response = requests.get(endpoint, params=params)
        if response.status_code == 429:
            print("Rate limit hit. Try again later.")
            print("try with back up account")
            try:
                return search_google_custom_backup(query, max_results)
            except Exception:
                # Was a bare `except:` — don't swallow SystemExit/KeyboardInterrupt.
                return []
        response.raise_for_status()
        data = response.json().get("items", [])
        return [item.get("link") for item in data if item.get("link")]
    except Exception as e:
        print("Google CSE error:", e)
        return []
333
-
334
def search_google_custom_backup(query, max_results=3):
    """Same as search_google_custom, but using the backup CSE credentials
    (GOOGLE_CSE_API_KEY_BACKUP / GOOGLE_CSE_CX_BACKUP).

    Returns:
        list[str]: up to *max_results* result links, or [] on error or
        rate limiting.
    """
    api_key = os.environ["GOOGLE_CSE_API_KEY_BACKUP"]
    cse_id = os.environ["GOOGLE_CSE_CX_BACKUP"]
    endpoint = os.environ["SEARCH_ENDPOINT"]
    query_params = {
        "key": api_key,
        "cx": cse_id,
        "q": query,
        "num": max_results,
    }
    try:
        response = requests.get(endpoint, params=query_params)
        if response.status_code == 429:
            print("Rate limit hit. Try again later.")
            return []
        response.raise_for_status()
        items = response.json().get("items", [])
        return [item.get("link") for item in items if item.get("link")]
    except Exception as err:
        print("Google CSE error:", err)
        return []
356
- # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
357
- # Step 3.1: Extract Text
358
- # sub: download excel file
359
def download_excel_file(url, save_path="temp.xlsx"):
    """Download an Excel file to *save_path*.

    Handles three cases:
    * an Office web-viewer URL (view.officeapps.live.com) — the real
      file URL is extracted from the 'src' query parameter;
    * a direct http(s) link ending in .xls/.xlsx;
    * anything else is assumed to be an already-local path and is
      returned unchanged.

    Returns:
        str: *save_path* after downloading, or *url* unchanged for the
        local-path case.
    """
    if "view.officeapps.live.com" in url:
        parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
        real_url = urllib.parse.unquote(parsed_url["src"][0])
        response = requests.get(real_url)
        # Fail loudly instead of saving an error page (was missing here,
        # inconsistent with the direct-link branch below).
        response.raise_for_status()
        with open(save_path, "wb") as f:
            f.write(response.content)
        return save_path
    elif url.startswith("http") and url.endswith((".xls", ".xlsx")):
        response = requests.get(url)
        response.raise_for_status()  # Raises error if download fails
        with open(save_path, "wb") as f:
            f.write(response.content)
        return save_path
    else:
        print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
        return url
376
def get_paper_text(doi, id, manualLinks=None):
    """Collect text from a paper (via its DOI landing page) and its
    supplementary materials (pdf / doc / docx / xls / xlsx).

    Downloaded files land in data/<id>/.  Returns a dict mapping each
    link (or local file path) to its extracted text, e.g.
    {doiLink: paperText, "file1.pdf": text1, "file3.xlsx": excelText3}.

    Args:
        doi: the paper's DOI, resolved via https://doi.org/.
        id: identifier (e.g. PubMed ID) used to name the cache folder.
        manualLinks: optional extra links to harvest in addition to the
            ones discovered on the landing page.
    """
    # Create the per-id cache folder (was shelling out to `mkdir`).
    folder_path = Path("data/" + str(id))
    if not folder_path.exists():
        folder_path.mkdir(parents=True, exist_ok=True)
        print("data/" + str(id) + " created.")
    else:
        print("data/" + str(id) + " already exists.")
    saveLinkFolder = "data/" + id

    link = 'https://doi.org/' + doi
    textsToExtract = {}
    # The landing page also yields the supplementary-material links.
    html = extractHTML.HTML("", link)
    jsonSM = html.getSupMaterial()
    text = ""
    links = [link] + sum((jsonSM[key] for key in jsonSM), [])
    if manualLinks is not None:
        links += manualLinks
    for l in links:
        name = l.split("/")[-1]
        file_path = folder_path / name
        if l == link:
            # The main paper: extract the HTML sections.
            text = html.getListSection()
            textsToExtract[link] = text
        elif l.endswith(".pdf"):
            if file_path.is_file():
                l = saveLinkFolder + "/" + name
                print("File exists.")
            p = pdf.PDF(l, saveLinkFolder, doi)
            f = p.openPDFFile()
            pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
            doc = fitz.open(pdf_path)
            text = "\n".join([page.get_text() for page in doc])
            textsToExtract[l] = text
        elif l.endswith(".doc") or l.endswith(".docx"):
            d = wordDoc.wordDoc(l, saveLinkFolder)
            text = d.extractTextByPage()
            textsToExtract[l] = text
        elif l.split(".")[-1].lower() in ("xls", "xlsx"):
            # Was `in "xlsx"`: a substring test that also matched extensions
            # like "x", "sx" or "lsx".
            wc = word2vec.word2Vec()
            # Download the excel file if it is not downloaded yet.
            savePath = saveLinkFolder + "/" + l.split("/")[-1]
            excelPath = download_excel_file(l, savePath)
            corpus = wc.tableTransformToCorpusText([], excelPath)
            text = ''
            for c in corpus:
                para = corpus[c]
                for words in para:
                    text += " ".join(words)
            textsToExtract[l] = text
    # NOTE: the cache folder data/<id> is intentionally kept (cleanup was
    # already disabled in the original).
    return textsToExtract
437
- # Step 3.2: Extract context
438
def extract_context(text, keyword, window=500):
    """Return a ±window-character slice of *text* around the first
    occurrence of *keyword*, or the sentinel string
    "Sample ID not found." when the keyword is absent."""
    idx = text.find(keyword)
    if idx == -1:
        return "Sample ID not found."
    return text[max(0, idx - window): idx + window]


def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
    """Reduce *text* to the passages most likely to describe the sample.

    Context around the accession (and isolate, when given) is taken
    first, then context around each generic keyword in *keep_if*.

    Args:
        text: full document text (lowercased internally).
        accession: accession number to prioritize, may be falsy.
        keep_if: keywords to scan for; defaults to a sampling/methods list.
        isolate: optional isolate name to prioritize.

    Returns:
        str: concatenated context snippets (may be empty).
    """
    if keep_if is None:
        keep_if = ["sample", "method", "mtdna", "sequence", "collected",
                   "dataset", "supplementary", "table"]

    outputs = ""
    text = text.lower()

    # Prioritize paragraphs that mention the accession / isolate.
    if accession and accession.lower() in text:
        ctx = extract_context(text, accession.lower(), window=700)
        if ctx != "Sample ID not found.":
            outputs += ctx
    if isolate and isolate.lower() in text:
        ctx = extract_context(text, isolate.lower(), window=700)
        if ctx != "Sample ID not found.":
            outputs += ctx
    for keyword in keep_if:
        para = extract_context(text, keyword)
        # Skip the not-found sentinel (the original appended it verbatim
        # to the output for the first missing keyword).
        if para != "Sample ID not found." and para not in outputs:
            outputs += para + "\n"
    return outputs
464
- # Step 4: Classification for now (demo purposes)
465
- # 4.1: Using a HuggingFace model (question-answering)
466
def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
    """Answer *question* against *context* with a DistilBERT QA pipeline.

    Returns:
        str: the model's answer, "Unknown" when the pipeline returns no
        answer, or an "Error: ..." message when anything fails.
    """
    try:
        qa_pipeline = pipeline("question-answering",
                               model="distilbert-base-uncased-distilled-squad")
        prediction = qa_pipeline({"context": context, "question": question})
        return prediction.get("answer", "Unknown")
    except Exception as exc:
        return f"Error: {str(exc)}"
473
-
474
- # 4.2: Infer from haplogroup
475
- # Load pre-trained spaCy model for NER
476
# Load the small English spaCy model used for NER (GPE location spans);
# if it is not installed, download it once and retry the load.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
481
-
482
- # Define the haplogroup-to-region mapping (simple rule-based)
483
- import csv
484
-
485
def load_haplogroup_mapping(csv_path):
    """Read the haplogroup lookup table from *csv_path*.

    The CSV must have 'haplogroup', 'region' and 'source' columns.

    Returns:
        dict: haplogroup -> [region, source].
    """
    with open(csv_path) as handle:
        return {row["haplogroup"]: [row["region"], row["source"]]
                for row in csv.DictReader(handle)}
492
-
493
- # Function to extract haplogroup from the text
494
def extract_haplogroup(text):
    """Pull a haplogroup label out of *text*.

    Prefers an explicit "haplogroup X..." mention, trimmed to its leading
    letter-plus-digits prefix (e.g. "B4a1" -> "B4"); otherwise falls back
    to the first short capitalized token.

    Returns:
        str | None: the haplogroup label, or None when nothing matches.
    """
    explicit = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
    if explicit:
        prefix = re.match(r'^[A-Z][0-9]*', explicit.group(1))
        return prefix.group(0) if prefix else explicit.group(1)
    loose = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
    return loose.group(1) if loose else None
506
-
507
-
508
- # Function to extract location based on NER
509
def extract_location(text):
    """Return every GPE (geopolitical entity) span spaCy finds in *text*."""
    parsed = nlp(text)
    return [ent.text for ent in parsed.ents if ent.label_ == "GPE"]
516
-
517
- # Function to infer location from haplogroup
518
def infer_location_from_haplogroup(haplogroup):
    """Map *haplogroup* to [region, source] via the bundled CSV table;
    unknown haplogroups yield ["Unknown", "Unknown"]."""
    region_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
    return region_map.get(haplogroup, ["Unknown", "Unknown"])
521
-
522
- # Function to classify the mtDNA sample
523
def classify_mtDNA_sample_from_haplo(text):
    """Classify an mtDNA sample from free text via its haplogroup.

    Extracts the haplogroup and any GPE location spans from *text*, then
    maps the haplogroup to a region using the CSV lookup table.

    Returns:
        dict: 'source', 'locations_found_in_context', 'haplogroup',
        'inferred_location'.
    """
    # Extract haplogroup
    haplogroup = extract_haplogroup(text)
    # Extract location based on NER
    locations = extract_location(text)
    # Look the haplogroup up once (the original called
    # infer_location_from_haplogroup twice, re-reading the CSV each time).
    inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)
    return {
        "source": sourceHaplo,
        "locations_found_in_context": locations,
        "haplogroup": haplogroup,
        "inferred_location": inferred_location
    }
537
- # 4.3 Get from available NCBI
538
def infer_location_fromNCBI(accession):
    """Fetch the Medline record for *accession* and scrape a location
    qualifier (geo_loc_name / country / location) out of it.

    Returns:
        tuple[str, str]: (value, matched qualifier line), or
        ("Not found", "Not found") when absent or on Entrez failure.
    """
    try:
        handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
        record_text = handle.read()
        handle.close()
        hit = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', record_text)
        if hit:
            # group(2) is the bare value, e.g. "Brunei".
            return hit.group(2), hit.group(0)
        return "Not found", "Not found"

    except Exception as e:
        print("❌ Entrez error:", e)
        return "Not found", "Not found"
551
-
552
- ### ANCIENT/MODERN FLAG
553
- from Bio import Entrez
554
- import re
555
-
556
def flag_ancient_modern(accession, textsToExtract, isolate=None):
    """
    Try to classify a sample as Ancient or Modern using:
    1. NCBI accession (if available)
    2. Supplementary text or context fallback

    Args:
        accession: NCBI accession number used for the Entrez lookup.
        textsToExtract: mapping of source name -> extracted text; used as
            a fallback when NCBI metadata is inconclusive.
        isolate: optional isolate name to prioritize when reducing the
            fallback texts to relevant paragraphs.

    Returns:
        (label, explanation): label is "Ancient", "Modern", "Unknown",
        or "" when nothing could be inferred or an error occurred.
    """
    context = ""
    label, explain = "", ""

    try:
        # Check if we can fetch metadata from NCBI using the accession
        handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
        text = handle.read()
        handle.close()

        # Collect the qualifiers most likely to carry ancient/modern hints.
        isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
        if isolate_source:
            context += isolate_source.group(0) + " "

        specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
        if specimen:
            context += specimen.group(0) + " "

        if context.strip():
            label, explain = detect_ancient_flag(context)
            if label!="Unknown":
                return label, explain + " from NCBI\n(" + context + ")"

        # If no useful NCBI metadata, check supplementary texts
        if textsToExtract:
            # Vote across sources: per-label hit count plus concatenated
            # per-source explanations ("unknown" only keeps a count).
            labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}

            for source in textsToExtract:
                text_block = textsToExtract[source]
                context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
                label, explain = detect_ancient_flag(context)

                if label == "Ancient":
                    labels["ancient"][0] += 1
                    labels["ancient"][1] += f"{source}:\n{explain}\n\n"
                elif label == "Modern":
                    labels["modern"][0] += 1
                    labels["modern"][1] += f"{source}:\n{explain}\n\n"
                else:
                    labels["unknown"] += 1

            # Majority vote across sources; ties favor "Ancient".
            if max(labels["modern"][0],labels["ancient"][0]) > 0:
                if labels["modern"][0] > labels["ancient"][0]:
                    return "Modern", labels["modern"][1]
                else:
                    return "Ancient", labels["ancient"][1]
            else:
                return "Unknown", "No strong keywords detected"
        else:
            print("No DOI or PubMed ID available for inference.")
            return "", ""

    except Exception as e:
        print("Error:", e)
        return "", ""
616
-
617
-
618
def detect_ancient_flag(context_snippet):
    """Decide whether a context snippet describes ancient or modern samples.

    Keyword hits decide first (a tie leans ancient); when neither keyword
    family matches, the QA model is consulted as a last resort.

    Returns:
        tuple[str, str]: (label, explanation), label in
        {"Ancient", "Modern", "Unknown"}.
    """
    lowered = context_snippet.lower()

    ancient_keywords = [
        "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
        "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
        "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
        "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
    ]

    modern_keywords = [
        "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
        "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
        "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
        "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
        "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
    ]

    ancient_hits = [kw for kw in ancient_keywords if kw in lowered]
    modern_hits = [kw for kw in modern_keywords if kw in lowered]

    if ancient_hits and not modern_hits:
        return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
    if modern_hits and not ancient_hits:
        return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
    if ancient_hits and modern_hits:
        if len(ancient_hits) >= len(modern_hits):
            return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
        return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"

    # No keywords fired at all: fall back to the QA model.
    answer = infer_fromQAModel(lowered, question="Are the mtDNA samples ancient or modern? Explain why.")
    if answer.startswith("Error"):
        return "Unknown", answer
    if "ancient" in answer.lower():
        return "Ancient", f"Leaning ancient based on QA: {answer}"
    elif "modern" in answer.lower():
        return "Modern", f"Leaning modern based on QA: {answer}"
    else:
        return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
659
-
660
- # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
661
# STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
def classify_sample_location(accession):
    """Run the full location-inference pipeline for one accession number.

    Returns a tuple ``(outputs, label, explain)`` where ``outputs`` is a nested
    dict keyed by sample id (both the accession and the isolate name) ->
    source -> method -> result dict, and ``label``/``explain`` are the
    ancient/modern flag produced by ``flag_ancient_modern``.
    """
    outputs = {}
    keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
    # Step 1: get pubmed id and isolate
    pubmedID, isolate = get_info_from_accession(accession)
    '''if not pubmedID:
        return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
    if not isolate:
        # Sentinel so the isolate key of `outputs` is always populated.
        isolate = "UNKNOWN_ISOLATE"
    # Step 2: get doi
    doi = get_doi_from_pubmed_id(pubmedID)
    '''if not doi:
        return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
    # Step 3: get text
    '''textsToExtract = { "doiLink":"paperText"
                          "file1.pdf":"text1",
                          "file2.doc":"text2",
                          "file3.xlsx":excelText3'''
    if doi and pubmedID:
        textsToExtract = get_paper_text(doi,pubmedID)
    else: textsToExtract = {}
    '''if not textsToExtract:
        return {"error": f"No texts extracted for DOI {doi}"}'''
    # Flag ancient/modern; pass the isolate only when it is a real name.
    if isolate not in [None, "UNKNOWN_ISOLATE"]:
        label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
    else:
        label, explain = flag_ancient_modern(accession,textsToExtract)
    # Step 4: prediction
    outputs[accession] = {}
    outputs[isolate] = {}
    # 4.0 Infer from NCBI
    location, outputNCBI = infer_location_fromNCBI(accession)
    NCBI_result = {
        "source": "NCBI",
        "sample_id": accession,
        "predicted_location": location,
        "context_snippet": outputNCBI}
    outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
    if textsToExtract:
        long_text = ""
        # Run QA + haplogroup inference once per extracted source text,
        # once keyed by accession and once keyed by isolate.
        for key in textsToExtract:
            text = textsToExtract[key]
            # try accession number first
            outputs[accession][key] = {}
            keyword = accession
            context = extract_context(text, keyword, window=500)
            # 4.1: Using a HuggingFace model (question-answering)
            location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
            qa_result = {
                "source": key,
                "sample_id": keyword,
                "predicted_location": location,
                "context_snippet": context
            }
            outputs[keyword][key]["QAModel"] = qa_result
            # 4.2: Infer from haplogroup
            haplo_result = classify_mtDNA_sample_from_haplo(context)
            outputs[keyword][key]["haplogroup"] = haplo_result
            # try isolate
            keyword = isolate
            outputs[isolate][key] = {}
            context = extract_context(text, keyword, window=500)
            # 4.1.1: Using a HuggingFace model (question-answering)
            location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
            qa_result = {
                "source": key,
                "sample_id": keyword,
                "predicted_location": location,
                "context_snippet": context
            }
            outputs[keyword][key]["QAModel"] = qa_result
            # 4.2.1: Infer from haplogroup
            haplo_result = classify_mtDNA_sample_from_haplo(context)
            outputs[keyword][key]["haplogroup"] = haplo_result
            # add long text
            long_text += text + ". \n"
        # 4.3: UpgradeClassify
        # try sample_id as accession number
        sample_id = accession
        if sample_id:
            filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
            locations = infer_location_for_sample(sample_id.upper(), filtered_context)
            if locations!="No clear location found in top matches":
                outputs[sample_id]["upgradeClassifier"] = {}
                outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
                    "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
                    "sample_id": sample_id,
                    "predicted_location": ", ".join(locations),
                    # NOTE(review): key is misspelled ("context_snippep", not
                    # "context_snippet"); downstream readers may depend on the
                    # misspelling, so it is left as-is.
                    "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
                }
        # try sample_id as isolate name
        sample_id = isolate
        if sample_id:
            filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
            locations = infer_location_for_sample(sample_id.upper(), filtered_context)
            if locations!="No clear location found in top matches":
                outputs[sample_id]["upgradeClassifier"] = {}
                outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
                    "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
                    "sample_id": sample_id,
                    "predicted_location": ", ".join(locations),
                    # NOTE(review): same misspelled key as above, kept for
                    # backward compatibility.
                    "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
                }

    return outputs, label, explain
 
1
# mtDNA Location Classifier MVP (Google Colab)
# Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
import os
#import streamlit as st
import subprocess
import re
from Bio import Entrez
import fitz
import spacy
from spacy.cli import download
from NER.PDF import pdf
from NER.WordDoc import wordDoc
from NER.html import extractHTML
from NER.word2Vec import word2vec
from transformers import pipeline
import urllib.parse, requests
from pathlib import Path
from upgradeClassify import filter_context_for_sample, infer_location_for_sample
import model
# Set your email (required by NCBI Entrez)
#Entrez.email = "your-email@example.com"
import nltk

# NLTK tokenizer/stopword data used by downstream text processing;
# downloaded (no-op if cached) at import time.
nltk.download("stopwords")
nltk.download("punkt")
nltk.download('punkt_tab')
# Step 1: Get PubMed ID from Accession using EDirect
from Bio import Entrez, Medline
import re

Entrez.email = "your_email@example.com"

# --- Helper Functions (Re-organized and Upgraded) ---
34
+
35
def fetch_ncbi_metadata(accession_number):
    """
    Fetches metadata directly from NCBI GenBank using Entrez.
    Includes robust error handling and improved field extraction.
    Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
    Also attempts to extract ethnicity and sample_type (ancient/modern).

    Args:
        accession_number (str): The NCBI accession number (e.g., "ON792208").

    Returns:
        dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
              'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id',
              and 'all_features' (the raw source-feature qualifiers, one per line).
    """
    Entrez.email = "your.email@example.com"  # Required by NCBI, REPLACE WITH YOUR EMAIL

    country = "unknown"
    specific_location = "unknown"
    ethnicity = "unknown"
    sample_type = "unknown"
    collection_date = "unknown"
    isolate = "unknown"
    title = "unknown"
    doi = "unknown"
    pubmed_id = None

    KNOWN_COUNTRIES = [
        "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
        "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
        "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
        "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
        "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
        "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
        "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
        "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
        "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
        "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
        "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
        "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
        "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
        "Yemen", "Zambia", "Zimbabwe"
    ]
    COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)

    try:
        handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
        record = Entrez.read(handle)
        handle.close()

        gb_seq = None
        # Validate record structure: It should be a list with at least one element (a dict)
        if isinstance(record, list) and len(record) > 0:
            if isinstance(record[0], dict):
                gb_seq = record[0]
            else:
                print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
        else:
            print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")

        # If gb_seq is still None, return defaults
        if gb_seq is None:
            return {"country": "unknown",
                    "specific_location": "unknown",
                    "ethnicity": "unknown",
                    "sample_type": "unknown",
                    "collection_date": "unknown",
                    "isolate": "unknown",
                    "title": "unknown",
                    "doi": "unknown",
                    "pubmed_id": None,
                    "all_features": "unknown"}

        # If gb_seq is valid, proceed with extraction.
        # Fallback date: the record's create-date, overridden below by an
        # explicit collection_date qualifier when present.
        collection_date = gb_seq.get("GBSeq_create-date","unknown")

        references = gb_seq.get("GBSeq_references", [])
        for ref in references:
            if not pubmed_id:
                pubmed_id = ref.get("GBReference_pubmed",None)
            if title == "unknown":
                title = ref.get("GBReference_title","unknown")
            for xref in ref.get("GBReference_xref", []):
                if xref.get("GBXref_dbname") == "doi":
                    doi = xref.get("GBXref_id")
                    break

        features = gb_seq.get("GBSeq_feature-table", [])

        context_for_flagging = ""  # Accumulate text for ancient/modern detection
        features_context = ""
        for feature in features:
            if feature.get("GBFeature_key") == "source":
                feature_context = ""
                qualifiers = feature.get("GBFeature_quals", [])
                found_country = "unknown"
                found_specific_location = "unknown"
                found_ethnicity = "unknown"

                # Candidate location/ethnicity qualifiers, resolved by priority below.
                temp_geo_loc_name = "unknown"
                temp_note_origin_locality = "unknown"
                temp_country_qual = "unknown"
                temp_locality_qual = "unknown"
                temp_collection_location_qual = "unknown"
                temp_isolation_source_qual = "unknown"
                temp_env_sample_qual = "unknown"
                temp_pop_qual = "unknown"
                temp_organism_qual = "unknown"
                temp_specimen_qual = "unknown"
                temp_strain_qual = "unknown"

                for qual in qualifiers:
                    qual_name = qual.get("GBQualifier_name")
                    qual_value = qual.get("GBQualifier_value")
                    feature_context += qual_name + ": " + qual_value +"\n"
                    if qual_name == "collection_date":
                        collection_date = qual_value
                    elif qual_name == "isolate":
                        isolate = qual_value
                    elif qual_name == "population":
                        temp_pop_qual = qual_value
                    elif qual_name == "organism":
                        temp_organism_qual = qual_value
                    elif qual_name == "specimen_voucher" or qual_name == "specimen":
                        temp_specimen_qual = qual_value
                    elif qual_name == "strain":
                        temp_strain_qual = qual_value
                    elif qual_name == "isolation_source":
                        temp_isolation_source_qual = qual_value
                    elif qual_name == "environmental_sample":
                        temp_env_sample_qual = qual_value

                    if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
                    elif qual_name == "note":
                        if qual_value.startswith("origin_locality:"):
                            temp_note_origin_locality = qual_value
                        context_for_flagging += qual_value + " "  # Capture all notes for flagging
                    elif qual_name == "country": temp_country_qual = qual_value
                    elif qual_name == "locality": temp_locality_qual = qual_value
                    elif qual_name == "collection_location": temp_collection_location_qual = qual_value

                # --- Aggregate all relevant info into context_for_flagging ---
                context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
                context_for_flagging = context_for_flagging.strip()

                # --- Determine final country and specific_location based on priority ---
                if temp_geo_loc_name != "unknown":
                    # geo_loc_name is usually "Country:region"
                    parts = [p.strip() for p in temp_geo_loc_name.split(':')]
                    if len(parts) > 1:
                        found_specific_location = parts[-1]
                        found_country = parts[0]
                    else:
                        found_country = temp_geo_loc_name
                        found_specific_location = "unknown"
                elif temp_note_origin_locality != "unknown":
                    match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
                    if match:
                        location_string = match.group(1).strip()
                        parts = [p.strip() for p in location_string.split(':')]
                        if len(parts) > 1:
                            found_country = model.get_country_from_text(temp_note_origin_locality.lower())
                            if found_country == "unknown":
                                found_country = parts[0]
                            found_specific_location = parts[-1]
                        else:
                            found_country = location_string
                            found_specific_location = "unknown"
                elif temp_locality_qual != "unknown":
                    found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
                    if found_country_match:
                        found_country = found_country_match.group(1)
                        temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip()
                        found_specific_location = temp_loc if temp_loc else "unknown"
                    else:
                        found_specific_location = temp_locality_qual
                        found_country = "unknown"
                elif temp_collection_location_qual != "unknown":
                    found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
                    if found_country_match:
                        found_country = found_country_match.group(1)
                        temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip()
                        found_specific_location = temp_loc if temp_loc else "unknown"
                    else:
                        found_specific_location = temp_collection_location_qual
                        found_country = "unknown"
                elif temp_isolation_source_qual != "unknown":
                    found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
                    if found_country_match:
                        found_country = found_country_match.group(1)
                        temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip()
                        found_specific_location = temp_loc if temp_loc else "unknown"
                    else:
                        found_specific_location = temp_isolation_source_qual
                        found_country = "unknown"
                elif temp_env_sample_qual != "unknown":
                    found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
                    if found_country_match:
                        found_country = found_country_match.group(1)
                        temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip()
                        found_specific_location = temp_loc if temp_loc else "unknown"
                    else:
                        found_specific_location = temp_env_sample_qual
                        found_country = "unknown"
                if found_country == "unknown" and temp_country_qual != "unknown":
                    found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
                    if found_country_match: found_country = found_country_match.group(1)

                country = found_country
                specific_location = found_specific_location
                # --- Determine final ethnicity ---
                if temp_pop_qual != "unknown":
                    found_ethnicity = temp_pop_qual
                elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
                    # A purely alphabetic isolate name that is not a country
                    # keyword is treated as a population/ethnicity label.
                    found_ethnicity = isolate
                elif context_for_flagging != "unknown":  # Use the broader context for ethnicity patterns
                    eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
                    if eth_match:
                        found_ethnicity = eth_match.group(1).strip()

                ethnicity = found_ethnicity

                # --- Determine sample_type (ancient/modern) ---
                if context_for_flagging:
                    sample_type, explain = detect_ancient_flag(context_for_flagging)
                features_context += feature_context + "\n"
                # Only the first "source" feature is considered.
                break

        if specific_location != "unknown" and specific_location.lower() == country.lower():
            specific_location = "unknown"
        if not features_context: features_context = "unknown"
        return {"country": country.lower(),
                "specific_location": specific_location.lower(),
                "ethnicity": ethnicity.lower(),
                "sample_type": sample_type.lower(),
                "collection_date": collection_date,
                "isolate": isolate,
                "title": title,
                "doi": doi,
                "pubmed_id": pubmed_id,
                "all_features": features_context}

    except Exception as e:
        # BUG FIX: was a bare `except:` that also swallowed SystemExit /
        # KeyboardInterrupt and hid the actual failure reason.
        print(f"Error fetching NCBI data for {accession_number}: {e}")
        return {"country": "unknown",
                "specific_location": "unknown",
                "ethnicity": "unknown",
                "sample_type": "unknown",
                "collection_date": "unknown",
                "isolate": "unknown",
                "title": "unknown",
                "doi": "unknown",
                "pubmed_id": None,
                "all_features": "unknown"}
266
+
267
# --- Helper function for country matching (re-defined from main code to be self-contained) ---
_country_keywords = {
    "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
    "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
    "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
    "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
    "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
    "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
    "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
    "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
    "central india": "India", "east india": "India", "northeast india": "India",
    "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
    "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
}

def get_country_from_text(text):
    """Return the canonical country for the first keyword found in *text*.

    Keywords are matched as case-insensitive substrings in the dict's
    insertion order; "unknown" is returned when nothing matches.
    """
    haystack = text.lower()
    return next(
        (place for kw, place in _country_keywords.items() if kw in haystack),
        "unknown",
    )
288
# The result will be seen as manualLink for the function get_paper_text
# (the dead, fully commented-out duplicate of this function was removed)
def search_google_custom(query, max_results=3):
    """Search Google Custom Search for *query* and return up to *max_results* links.

    *query* should be the title from NCBI or the paper/source title.
    On an HTTP 429 (quota exhausted), retries once with the backup
    credentials via ``search_google_custom_backup``. Returns [] on any error.
    """
    GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
    GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
    endpoint = os.environ["SEARCH_ENDPOINT"]
    params = {
        "key": GOOGLE_CSE_API_KEY,
        "cx": GOOGLE_CSE_CX,
        "q": query,
        "num": max_results
    }
    try:
        response = requests.get(endpoint, params=params)
        if response.status_code == 429:
            print("Rate limit hit. Try again later.")
            print("try with back up account")
            try:
                return search_google_custom_backup(query, max_results)
            except Exception:
                # BUG FIX: was a bare `except:` that also swallowed
                # SystemExit / KeyboardInterrupt.
                return []
        response.raise_for_status()
        data = response.json().get("items", [])
        return [item.get("link") for item in data if item.get("link")]
    except Exception as e:
        print("Google CSE error:", e)
        return []
338
+
339
def search_google_custom_backup(query, max_results=3):
    """Same search as ``search_google_custom`` but using the backup CSE
    credentials; *query* should be the NCBI/paper title. Returns [] on error."""
    request_params = {
        "key": os.environ["GOOGLE_CSE_API_KEY_BACKUP"],
        "cx": os.environ["GOOGLE_CSE_CX_BACKUP"],
        "q": query,
        "num": max_results,
    }
    endpoint = os.environ["SEARCH_ENDPOINT"]
    try:
        response = requests.get(endpoint, params=request_params)
        if response.status_code == 429:
            print("Rate limit hit. Try again later.")
            return []
        response.raise_for_status()
        items = response.json().get("items", [])
        return [entry.get("link") for entry in items if entry.get("link")]
    except Exception as e:
        print("Google CSE error:", e)
        return []
361
# Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
# Step 3.1: Extract Text
# sub: download excel file
def download_excel_file(url, save_path="temp.xlsx"):
    """Download an Excel workbook to *save_path* and return the local path.

    Handles Office-viewer wrapper URLs and direct .xls/.xlsx links; any other
    input is assumed to already be a local path/downloaded file and is
    returned unchanged.
    """
    if "view.officeapps.live.com" in url:
        # The real workbook URL is wrapped in the viewer's `src` query parameter.
        query = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
        real_url = urllib.parse.unquote(query["src"][0])
        payload = requests.get(real_url)
        with open(save_path, "wb") as f:
            f.write(payload.content)
        return save_path
    if url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
        payload = requests.get(url)
        payload.raise_for_status()  # Raises error if download fails
        with open(save_path, "wb") as f:
            f.write(payload.content)
        return save_path
    print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
    return url
381
def get_paper_text(doi,id,manualLinks=None):
    """Fetch a paper (via its DOI landing page) plus supplementary files and extract text.

    Returns a dict mapping each source (DOI link, or a file path/URL) to its
    extracted text, e.g.::

        { doiLink: paperText, "file1.pdf": text1, "file2.doc": text2, "file3.xlsx": excelText3 }
    """
    # create the temporary folder to contain the texts
    folder_path = Path("data/"+str(id))
    if not folder_path.exists():
        # was: subprocess.run(f'mkdir data/{id}', shell=True) — shelling out is
        # unnecessary and non-portable; create the directory directly instead.
        folder_path.mkdir(parents=True, exist_ok=True)
        print("data/"+str(id) +" created.")
    else:
        print("data/"+str(id) +" already exists.")
    saveLinkFolder = "data/"+id

    link = 'https://doi.org/' + doi
    textsToExtract = {}
    # get the file to create listOfFile for each id
    html = extractHTML.HTML("",link)
    jsonSM = html.getSupMaterial()
    text = ""
    links = [link] + sum((jsonSM[key] for key in jsonSM),[])
    if manualLinks is not None:
        links += manualLinks
    for l in links:
        # get the main paper
        name = l.split("/")[-1]
        file_path = folder_path / name
        if l == link:
            text = html.getListSection()
            textsToExtract[link] = text
        elif l.endswith(".pdf"):
            if file_path.is_file():
                l = saveLinkFolder + "/" + name
                print("File exists.")
            p = pdf.PDF(l,saveLinkFolder,doi)
            p.openPDFFile()  # presumably fetches/opens the PDF into saveLinkFolder — TODO confirm
            pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
            doc = fitz.open(pdf_path)
            text = "\n".join([page.get_text() for page in doc])
            textsToExtract[l] = text
        elif l.endswith(".doc") or l.endswith(".docx"):
            d = wordDoc.wordDoc(l,saveLinkFolder)
            text = d.extractTextByPage()
            textsToExtract[l] = text
        elif l.split(".")[-1].lower() in ("xls", "xlsx"):
            # BUG FIX: the old test `ext in "xlsx"` was a substring check, so
            # stray extensions such as "x", "ls" or "sx" also matched;
            # compare against the real extension set instead.
            wc = word2vec.word2Vec()
            # download excel file if it not downloaded yet
            savePath = saveLinkFolder +"/"+ l.split("/")[-1]
            excelPath = download_excel_file(l, savePath)
            corpus = wc.tableTransformToCorpusText([],excelPath)
            text = ''
            for c in corpus:
                para = corpus[c]
                for words in para:
                    text += " ".join(words)
            textsToExtract[l] = text
    # delete folder after finishing getting text
    #cmd = f'rm -r data/{id}'
    #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    return textsToExtract
442
# Step 3.2: Extract context
def extract_context(text, keyword, window=500):
    """Return up to *window* characters either side of the first *keyword* hit.

    Returns the sentinel string "Sample ID not found." when *keyword* is
    absent (callers compare against that exact string).
    """
    position = text.find(keyword)
    if position < 0:
        return "Sample ID not found."
    start = max(0, position - window)
    return text[start: position + window]
449
def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
    """Collect context windows around the accession, the isolate, and generic
    methods-section keywords from a (lowercased) document text."""
    if keep_if is None:
        keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]

    outputs = ""
    text = text.lower()

    # Prioritize windows around the accession and (optionally) the isolate name.
    for ident in (accession, isolate):
        if ident and ident.lower() in text:
            window = extract_context(text, ident.lower(), window=700)
            if window != "Sample ID not found.":
                outputs += window

    # Then append one window per generic keyword, skipping duplicates.
    for keyword in keep_if:
        para = extract_context(text, keyword)
        if para and para not in outputs:
            outputs += para + "\n"
    return outputs
469
# Step 4: Classification for now (demo purposes)
# 4.1: Using a HuggingFace model (question-answering)
def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
    """Answer *question* over *context* with a DistilBERT extractive-QA pipeline.

    Returns the answer string, "Unknown" when the model yields none, or an
    "Error: ..." string on failure. Note the pipeline is rebuilt on every call.
    """
    try:
        answerer = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
        prediction = answerer({"context": context, "question": question})
        return prediction.get("answer", "Unknown")
    except Exception as e:
        return f"Error: {str(e)}"
478
+
479
# 4.2: Infer from haplogroup
# Load pre-trained spaCy model for NER
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Model not installed yet: download it once, then load.
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
486
+
487
# Define the haplogroup-to-region mapping (simple rule-based)
import csv

def load_haplogroup_mapping(csv_path):
    """Read a CSV with haplogroup/region/source columns into
    ``{haplogroup: [region, source]}``."""
    mapping = {}
    with open(csv_path) as handle:
        for row in csv.DictReader(handle):
            mapping[row["haplogroup"]] = [row["region"], row["source"]]
    return mapping
497
+
498
# Function to extract haplogroup from the text
def extract_haplogroup(text):
    """Pull a haplogroup label out of free text.

    Prefers an explicit "haplogroup X..." mention, truncated to its
    letter+digits stem (e.g. "B4a1" -> "B4"); otherwise falls back to the
    first short capitalized token. Returns None when nothing matches.
    """
    explicit = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
    if explicit:
        stem = re.match(r'^[A-Z][0-9]*', explicit.group(1))
        return stem.group(0) if stem else explicit.group(1)
    fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
    return fallback.group(1) if fallback else None
511
+
512
+
513
# Function to extract location based on NER
def extract_location(text):
    """Return every GPE (geopolitical entity) string spaCy finds in *text*."""
    doc = nlp(text)
    return [ent.text for ent in doc.ents if ent.label_ == "GPE"]
521
+
522
# Function to infer location from haplogroup
def infer_location_from_haplogroup(haplogroup):
    """Map a haplogroup to ``[region, source]`` via the bundled CSV;
    ``["Unknown", "Unknown"]`` when absent.

    NOTE(review): the CSV is re-read on every call — consider caching.
    """
    table = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
    return table.get(haplogroup, ["Unknown", "Unknown"])
526
+
527
# Function to classify the mtDNA sample
def classify_mtDNA_sample_from_haplo(text):
    """Classify a sample from free text via its haplogroup.

    Returns a dict with the haplogroup-mapping source, GPE locations found by
    NER in the context, the extracted haplogroup, and the region inferred
    from the haplogroup table.
    """
    # Extract haplogroup
    haplogroup = extract_haplogroup(text)
    # Extract location based on NER
    locations = extract_location(text)
    # Infer location based on haplogroup. FIX: the old code called
    # infer_location_from_haplogroup twice (re-reading the CSV each time)
    # just to take elements [0] and [1]; one call + unpack is enough.
    inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)
    return {
        "source":sourceHaplo,
        "locations_found_in_context": locations,
        "haplogroup": haplogroup,
        "inferred_location": inferred_location

    }
542
# 4.3 Get from available NCBI
def infer_location_fromNCBI(accession):
    """Scrape a geo_loc_name/country/location qualifier from the GenBank flat record.

    Returns ``(value, matched_line)``; ``("Not found", "Not found")`` when the
    qualifier is absent or Entrez fails.
    """
    try:
        handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
        flat_record = handle.read()
        handle.close()
        hit = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', flat_record)
        if hit:
            return hit.group(2), hit.group(0)  # This is the value like "Brunei"
        return "Not found", "Not found"

    except Exception as e:
        print("❌ Entrez error:", e)
        return "Not found", "Not found"
556
+
557
+ ### ANCIENT/MODERN FLAG
558
+ from Bio import Entrez
559
+ import re
560
+
561
def flag_ancient_modern(accession, textsToExtract, isolate=None):
    """
    Try to classify a sample as Ancient or Modern using:
    1. NCBI accession (if available)
    2. Supplementary text or context fallback

    Returns (label, explanation); ("", "") when there is nothing to infer
    from or on error.
    """
    context = ""
    label, explain = "", ""

    try:
        # Check if we can fetch metadata from NCBI using the accession
        handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
        text = handle.read()
        handle.close()

        # Pull the qualifiers most indicative of sample provenance.
        isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
        if isolate_source:
            context += isolate_source.group(0) + " "

        specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
        if specimen:
            context += specimen.group(0) + " "

        if context.strip():
            label, explain = detect_ancient_flag(context)
            if label!="Unknown":
                return label, explain + " from NCBI\n(" + context + ")"

        # If no useful NCBI metadata, check supplementary texts
        if textsToExtract:
            # Per-label vote counts and accumulated per-source explanations.
            labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}

            for source in textsToExtract:
                text_block = textsToExtract[source]
                context = extract_relevant_paragraphs(text_block, accession, isolate=isolate)  # Reduce to informative paragraph(s)
                label, explain = detect_ancient_flag(context)

                if label == "Ancient":
                    labels["ancient"][0] += 1
                    labels["ancient"][1] += f"{source}:\n{explain}\n\n"
                elif label == "Modern":
                    labels["modern"][0] += 1
                    labels["modern"][1] += f"{source}:\n{explain}\n\n"
                else:
                    labels["unknown"] += 1

            # Majority vote across sources; ties go to "Ancient".
            if max(labels["modern"][0],labels["ancient"][0]) > 0:
                if labels["modern"][0] > labels["ancient"][0]:
                    return "Modern", labels["modern"][1]
                else:
                    return "Ancient", labels["ancient"][1]
            else:
                return "Unknown", "No strong keywords detected"
        else:
            print("No DOI or PubMed ID available for inference.")
            return "", ""

    except Exception as e:
        # Best-effort: any failure (network, parsing) degrades to "no flag".
        print("Error:", e)
        return "", ""
621
+
622
+
623
def detect_ancient_flag(context_snippet):
    """Keyword-vote whether a context describes ancient or modern samples.

    Counts ancient vs modern keyword hits (ties lean ancient); when neither
    list matches, defers to the extractive-QA model. Returns
    (label, explanation) where label is "Ancient", "Modern" or "Unknown".
    """
    context = context_snippet.lower()

    ancient_keywords = [
        "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
        "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
        "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
        "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
    ]

    modern_keywords = [
        "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
        "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
        "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
        "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
        "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
    ]

    ancient_hits = [kw for kw in ancient_keywords if kw in context]
    modern_hits = [kw for kw in modern_keywords if kw in context]

    if ancient_hits or modern_hits:
        if not modern_hits:
            return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
        if not ancient_hits:
            return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
        # Both sides matched: larger hit count wins, ties lean ancient.
        if len(ancient_hits) >= len(modern_hits):
            return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
        return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"

    # Fallback to QA
    answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
    if answer.startswith("Error"):
        return "Unknown", answer
    lowered = answer.lower()
    if "ancient" in lowered:
        return "Ancient", f"Leaning ancient based on QA: {answer}"
    if "modern" in lowered:
        return "Modern", f"Leaning modern based on QA: {answer}"
    return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
664
+
665
# STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
def _qa_and_haplo_for_keyword(outputs, keyword, source_key, text):
    """Run the QA-model and haplogroup location inferences for one keyword
    (accession or isolate name) against one extracted text source, storing
    both results under outputs[keyword][source_key]."""
    outputs[keyword][source_key] = {}
    # +/-500-character window of text around the keyword.
    context = extract_context(text, keyword, window=500)
    # 4.1: question-answering model over that context.
    location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
    outputs[keyword][source_key]["QAModel"] = {
        "source": source_key,
        "sample_id": keyword,
        "predicted_location": location,
        "context_snippet": context
    }
    # 4.2: independent inference from any haplogroup mentioned in the context.
    outputs[keyword][source_key]["haplogroup"] = classify_mtDNA_sample_from_haplo(context)


def _run_upgrade_classifier(outputs, sample_id, long_text, source_names):
    """Run the upgrade classifier over the concatenated paper text for one
    sample id, recording a result only when a clear location is found."""
    if not sample_id:
        return
    filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
    locations = infer_location_for_sample(sample_id.upper(), filtered_context)
    if locations != "No clear location found in top matches":
        outputs[sample_id]["upgradeClassifier"] = {
            "upgradeClassifier": {
                "source": "From these sources combined: " + ", ".join(source_names),
                "sample_id": sample_id,
                "predicted_location": ", ".join(locations),
                # NOTE(review): "context_snippep" is a typo in the original key,
                # kept byte-identical because downstream consumers may read it.
                "context_snippep": "First 1000 words: \n" + filtered_context[:1000]
            }
        }


def classify_sample_location(accession):
    """Full pipeline for one accession: look up PubMed id and isolate name,
    resolve the DOI, pull the paper text, and predict the sample's location
    via NCBI metadata, a QA model, haplogroup matching, and the upgrade
    classifier; also flag the sample as ancient or modern.

    Args:
        accession (str): NCBI accession number for the mtDNA sample.

    Returns:
        tuple: (outputs, label, explain) where outputs maps each sample id
        (accession and isolate) to per-source prediction dicts, and
        label/explain describe the ancient-vs-modern call.
    """
    outputs = {}
    # Step 1: get pubmed id and isolate
    pubmedID, isolate = get_info_from_accession(accession)
    if not isolate:
        isolate = "UNKNOWN_ISOLATE"
    # Step 2: get doi
    doi = get_doi_from_pubmed_id(pubmedID)
    # Step 3: get text, e.g. {"doiLink": paperText, "file1.pdf": text1, ...}
    if doi and pubmedID:
        textsToExtract = get_paper_text(doi, pubmedID)
    else:
        textsToExtract = {}
    # Ancient/modern flag; pass the isolate only when it is a real name.
    if isolate not in [None, "UNKNOWN_ISOLATE"]:
        label, explain = flag_ancient_modern(accession, textsToExtract, isolate)
    else:
        label, explain = flag_ancient_modern(accession, textsToExtract)
    # Step 4: prediction
    outputs[accession] = {}
    outputs[isolate] = {}
    # 4.0: infer directly from NCBI metadata.
    location, outputNCBI = infer_location_fromNCBI(accession)
    outputs[accession]["NCBI"] = {"NCBI": {
        "source": "NCBI",
        "sample_id": accession,
        "predicted_location": location,
        "context_snippet": outputNCBI
    }}
    if textsToExtract:
        long_text = ""
        for key in textsToExtract:
            text = textsToExtract[key]
            # 4.1/4.2: QA model + haplogroup, for the accession then the isolate.
            _qa_and_haplo_for_keyword(outputs, accession, key, text)
            _qa_and_haplo_for_keyword(outputs, isolate, key, text)
            # Accumulate every source for the combined-text classifier below.
            long_text += text + ". \n"
        # 4.3: upgrade classifier over all sources combined, for both ids.
        source_names = list(textsToExtract.keys())
        _run_upgrade_classifier(outputs, accession, long_text, source_names)
        _run_upgrade_classifier(outputs, isolate, long_text, source_names)
    return outputs, label, explain