github-actions[bot] commited on
Commit
31df32c
·
1 Parent(s): ac15317

Auto-sync from demo at Thu Dec 25 14:16:44 UTC 2025

Browse files
graphgen/bases/base_operator.py CHANGED
@@ -6,11 +6,12 @@ from typing import Iterable, Union
6
  import pandas as pd
7
  import ray
8
 
9
- from graphgen.utils import CURRENT_LOGGER_VAR, set_logger
10
-
11
 
12
  class BaseOperator(ABC):
13
  def __init__(self, working_dir: str = "cache", op_name: str = None):
 
 
 
14
  log_dir = os.path.join(working_dir, "logs")
15
  self.op_name = op_name or self.__class__.__name__
16
 
@@ -39,6 +40,9 @@ class BaseOperator(ABC):
39
  def __call__(
40
  self, batch: pd.DataFrame
41
  ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
 
 
 
42
  logger_token = CURRENT_LOGGER_VAR.set(self.logger)
43
  try:
44
  result = self.process(batch)
 
6
  import pandas as pd
7
  import ray
8
 
 
 
9
 
10
  class BaseOperator(ABC):
11
  def __init__(self, working_dir: str = "cache", op_name: str = None):
12
+ # lazy import to avoid circular import
13
+ from graphgen.utils import set_logger
14
+
15
  log_dir = os.path.join(working_dir, "logs")
16
  self.op_name = op_name or self.__class__.__name__
17
 
 
40
  def __call__(
41
  self, batch: pd.DataFrame
42
  ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
43
+ # lazy import to avoid circular import
44
+ from graphgen.utils import CURRENT_LOGGER_VAR
45
+
46
  logger_token = CURRENT_LOGGER_VAR.set(self.logger)
47
  try:
48
  result = self.process(batch)
graphgen/bases/base_reader.py CHANGED
@@ -39,6 +39,8 @@ class BaseReader(ABC):
39
  "table",
40
  "equation",
41
  "protein",
 
 
42
  ], f"Unsupported item type: {item_type}"
43
  if item_type == "text":
44
  content = item.get(self.text_column, "").strip()
 
39
  "table",
40
  "equation",
41
  "protein",
42
+ "dna",
43
+ "rna",
44
  ], f"Unsupported item type: {item_type}"
45
  if item_type == "text":
46
  content = item.get(self.text_column, "").strip()
graphgen/bases/base_searcher.py CHANGED
@@ -1,5 +1,5 @@
1
  from abc import ABC, abstractmethod
2
- from typing import Any, Dict, List
3
 
4
 
5
  class BaseSearcher(ABC):
@@ -8,11 +8,11 @@ class BaseSearcher(ABC):
8
  """
9
 
10
  @abstractmethod
11
- async def search(self, query: str, **kwargs) -> List[Dict[str, Any]]:
12
  """
13
  Search for data based on the given query.
14
 
15
  :param query: The searcher query.
16
  :param kwargs: Additional keyword arguments for the searcher.
17
- :return: List of dictionaries containing the searcher results.
18
  """
 
1
  from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, Optional
3
 
4
 
5
  class BaseSearcher(ABC):
 
8
  """
9
 
10
  @abstractmethod
11
+ def search(self, query: str, **kwargs) -> Optional[Dict[str, Any]]:
12
  """
13
  Search for data based on the given query.
14
 
15
  :param query: The searcher query.
16
  :param kwargs: Additional keyword arguments for the searcher.
17
+ :return: Dictionary containing the searcher result, or None if not found.
18
  """
graphgen/models/searcher/db/ncbi_searcher.py CHANGED
@@ -1,10 +1,7 @@
1
- import asyncio
2
  import os
3
  import re
4
  import subprocess
5
  import tempfile
6
- from concurrent.futures import ThreadPoolExecutor
7
- from functools import lru_cache
8
  from http.client import IncompleteRead
9
  from typing import Dict, Optional
10
 
@@ -22,15 +19,6 @@ from graphgen.bases import BaseSearcher
22
  from graphgen.utils import logger
23
 
24
 
25
- @lru_cache(maxsize=None)
26
- def _get_pool():
27
- return ThreadPoolExecutor(max_workers=10)
28
-
29
-
30
- # ensure only one NCBI request at a time
31
- _ncbi_lock = asyncio.Lock()
32
-
33
-
34
  class NCBISearch(BaseSearcher):
35
  """
36
  NCBI Search client to search DNA/GenBank/Entrez databases.
@@ -49,6 +37,8 @@ class NCBISearch(BaseSearcher):
49
  email: str = "email@example.com",
50
  api_key: str = "",
51
  tool: str = "GraphGen",
 
 
52
  ):
53
  """
54
  Initialize the NCBI Search client.
@@ -59,8 +49,8 @@ class NCBISearch(BaseSearcher):
59
  email (str): Email address for NCBI API requests.
60
  api_key (str): API key for NCBI API requests, see https://account.ncbi.nlm.nih.gov/settings/.
61
  tool (str): Tool name for NCBI API requests.
 
62
  """
63
- super().__init__()
64
  Entrez.timeout = 60 # 60 seconds timeout
65
  Entrez.email = email
66
  Entrez.tool = tool
@@ -70,9 +60,23 @@ class NCBISearch(BaseSearcher):
70
  Entrez.sleep_between_tries = 5
71
  self.use_local_blast = use_local_blast
72
  self.local_blast_db = local_blast_db
73
- if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.nhr"):
74
- logger.error("Local BLAST database files not found. Please check the path.")
75
- self.use_local_blast = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  @staticmethod
78
  def _nested_get(data: dict, *keys, default=None):
@@ -84,17 +88,21 @@ class NCBISearch(BaseSearcher):
84
  return data
85
 
86
  @staticmethod
87
- def _infer_molecule_type_detail(accession: Optional[str], gene_type: Optional[int] = None) -> Optional[str]:
 
 
88
  """Infer molecule_type_detail from accession prefix or gene type."""
89
  if accession:
90
- if accession.startswith(("NM_", "XM_")):
91
- return "mRNA"
92
- if accession.startswith(("NC_", "NT_")):
93
- return "genomic DNA"
94
- if accession.startswith(("NR_", "XR_")):
95
- return "RNA"
96
- if accession.startswith("NG_"):
97
- return "genomic region"
 
 
98
  # Fallback: infer from gene type if available
99
  if gene_type is not None:
100
  gene_type_map = {
@@ -126,20 +134,25 @@ class NCBISearch(BaseSearcher):
126
  gene_synonyms = []
127
  if isinstance(synonyms_raw, list):
128
  for syn in synonyms_raw:
129
- gene_synonyms.append(syn.get("Gene-ref_syn_E") if isinstance(syn, dict) else str(syn))
 
 
130
  elif synonyms_raw:
131
  gene_synonyms.append(str(synonyms_raw))
132
 
133
  # Extract location info
134
  label = locus.get("Gene-commentary_label", "")
135
- chromosome_match = re.search(r"Chromosome\s+(\S+)", str(label)) if label else None
 
 
136
 
137
  seq_interval = self._nested_get(
138
  locus, "Gene-commentary_seqs", 0, "Seq-loc_int", "Seq-interval", default={}
139
  )
140
  genomic_location = (
141
  f"{seq_interval.get('Seq-interval_from')}-{seq_interval.get('Seq-interval_to')}"
142
- if seq_interval.get('Seq-interval_from') and seq_interval.get('Seq-interval_to')
 
143
  else None
144
  )
145
 
@@ -153,7 +166,6 @@ class NCBISearch(BaseSearcher):
153
  None,
154
  )
155
  # Fallback: if no type 3 accession, try any available accession
156
- # This is needed for genes that don't have mRNA transcripts but have other sequence records
157
  if not representative_accession:
158
  representative_accession = next(
159
  (
@@ -170,7 +182,8 @@ class NCBISearch(BaseSearcher):
170
  comment.get("Gene-commentary_comment")
171
  for comment in data.get("Entrezgene_comments", [])
172
  if isinstance(comment, dict)
173
- and "function" in str(comment.get("Gene-commentary_heading", "")).lower()
 
174
  ),
175
  None,
176
  )
@@ -194,7 +207,9 @@ class NCBISearch(BaseSearcher):
194
  "5": "snRNA",
195
  "6": "ncRNA",
196
  "7": "other",
197
- }.get(str(data.get("Entrezgene_type")), f"type_{data.get('Entrezgene_type')}"),
 
 
198
  "chromosome": chromosome_match.group(1) if chromosome_match else None,
199
  "genomic_location": genomic_location,
200
  "function": function,
@@ -209,25 +224,33 @@ class NCBISearch(BaseSearcher):
209
  "_representative_accession": representative_accession,
210
  }
211
 
212
- def get_by_gene_id(self, gene_id: str, preferred_accession: Optional[str] = None) -> Optional[dict]:
 
 
 
 
 
 
 
 
213
  """Get gene information by Gene ID."""
 
214
  def _extract_metadata_from_genbank(result: dict, accession: str):
215
  """Extract metadata from GenBank format (title, features, organism, etc.)."""
216
- with Entrez.efetch(db="nuccore", id=accession, rettype="gb", retmode="text") as handle:
 
 
217
  record = SeqIO.read(handle, "genbank")
218
 
219
  result["title"] = record.description
220
  result["molecule_type_detail"] = (
221
- "mRNA" if accession.startswith(("NM_", "XM_")) else
222
- "genomic DNA" if accession.startswith(("NC_", "NT_")) else
223
- "RNA" if accession.startswith(("NR_", "XR_")) else
224
- "genomic region" if accession.startswith("NG_") else "N/A"
225
  )
226
 
227
  for feature in record.features:
228
  if feature.type == "source":
229
- if 'chromosome' in feature.qualifiers:
230
- result["chromosome"] = feature.qualifiers['chromosome'][0]
231
 
232
  if feature.location:
233
  start = int(feature.location.start) + 1
@@ -236,48 +259,91 @@ class NCBISearch(BaseSearcher):
236
 
237
  break
238
 
239
- if not result.get("organism") and 'organism' in record.annotations:
240
- result["organism"] = record.annotations['organism']
241
 
242
  return result
243
 
244
  def _extract_sequence_from_fasta(result: dict, accession: str):
245
  """Extract sequence from FASTA format (more reliable than GenBank for CON-type records)."""
246
  try:
247
- with Entrez.efetch(db="nuccore", id=accession, rettype="fasta", retmode="text") as fasta_handle:
 
 
248
  fasta_record = SeqIO.read(fasta_handle, "fasta")
249
  result["sequence"] = str(fasta_record.seq)
250
  result["sequence_length"] = len(fasta_record.seq)
251
  except Exception as fasta_exc:
252
  logger.warning(
253
  "Failed to extract sequence from accession %s using FASTA format: %s",
254
- accession, fasta_exc
 
255
  )
256
  result["sequence"] = None
257
  result["sequence_length"] = None
258
  return result
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  try:
261
  with Entrez.efetch(db="gene", id=gene_id, retmode="xml") as handle:
262
  gene_record = Entrez.read(handle)
263
- if not gene_record:
264
- return None
265
 
266
- result = self._gene_record_to_dict(gene_record, gene_id)
267
- if accession := (preferred_accession or result.get("_representative_accession")):
268
- result = _extract_metadata_from_genbank(result, accession)
269
- result = _extract_sequence_from_fasta(result, accession)
 
 
 
 
 
 
 
270
 
271
- result.pop("_representative_accession", None)
272
- return result
273
  except (RequestException, IncompleteRead):
274
  raise
275
  except Exception as exc:
276
  logger.error("Gene ID %s not found: %s", gene_id, exc)
277
  return None
278
 
 
 
 
 
 
 
279
  def get_by_accession(self, accession: str) -> Optional[dict]:
280
  """Get sequence information by accession number."""
 
281
  def _extract_gene_id(link_handle):
282
  """Extract GeneID from elink results."""
283
  links = Entrez.read(link_handle)
@@ -301,9 +367,11 @@ class NCBISearch(BaseSearcher):
301
  return None
302
 
303
  result = self.get_by_gene_id(gene_id, preferred_accession=accession)
 
304
  if result:
305
  result["id"] = accession
306
  result["url"] = f"https://www.ncbi.nlm.nih.gov/nuccore/{accession}"
 
307
  return result
308
  except (RequestException, IncompleteRead):
309
  raise
@@ -311,6 +379,12 @@ class NCBISearch(BaseSearcher):
311
  logger.error("Accession %s not found: %s", accession, exc)
312
  return None
313
 
 
 
 
 
 
 
314
  def get_best_hit(self, keyword: str) -> Optional[dict]:
315
  """Search NCBI Gene database with a keyword and return the best hit."""
316
  if not keyword.strip():
@@ -318,33 +392,113 @@ class NCBISearch(BaseSearcher):
318
 
319
  try:
320
  for search_term in [f"{keyword}[Gene] OR {keyword}[All Fields]", keyword]:
321
- with Entrez.esearch(db="gene", term=search_term, retmax=1, sort="relevance") as search_handle:
 
 
322
  search_results = Entrez.read(search_handle)
323
- if len(gene_id := search_results.get("IdList", [])) > 0:
324
- return self.get_by_gene_id(gene_id)
 
 
325
  except (RequestException, IncompleteRead):
326
  raise
327
  except Exception as e:
328
  logger.error("Keyword %s not found: %s", keyword, e)
329
  return None
330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
332
- """Perform local BLAST search using local BLAST database."""
 
 
 
333
  try:
334
- with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp:
 
 
335
  tmp.write(f">query\n{seq}\n")
336
  tmp_name = tmp.name
337
 
 
 
 
 
 
338
  cmd = [
339
- "blastn", "-db", self.local_blast_db, "-query", tmp_name,
340
- "-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc"
 
 
 
 
 
 
 
 
 
 
 
341
  ]
342
- logger.debug("Running local blastn: %s", " ".join(cmd))
343
- out = subprocess.check_output(cmd, text=True).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  os.remove(tmp_name)
345
  return out.split("\n", maxsplit=1)[0] if out else None
346
  except Exception as exc:
347
  logger.error("Local blastn failed: %s", exc)
 
 
 
 
 
 
348
  return None
349
 
350
  def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
@@ -358,8 +512,9 @@ class NCBISearch(BaseSearcher):
358
  seq = sequence.strip().replace(" ", "").replace("\n", "")
359
  return seq if re.fullmatch(r"[ATCGN]+", seq, re.I) else None
360
 
361
-
362
- def _process_network_blast_result(blast_record, seq: str, threshold: float) -> Optional[dict]:
 
363
  """Process network BLAST result and return dictionary or None."""
364
  if not blast_record.alignments:
365
  logger.info("No BLAST hits found for the given sequence.")
@@ -383,7 +538,9 @@ class NCBISearch(BaseSearcher):
383
  "title": best_alignment.title,
384
  "sequence_length": len(seq),
385
  "e_value": best_hsp.expect,
386
- "identity": best_hsp.identities / best_hsp.align_length if best_hsp.align_length > 0 else 0,
 
 
387
  "url": f"https://www.ncbi.nlm.nih.gov/nuccore/{hit_id}",
388
  }
389
 
@@ -393,15 +550,31 @@ class NCBISearch(BaseSearcher):
393
  return None
394
 
395
  # Try local BLAST first if enabled
396
- if self.use_local_blast and (accession := self._local_blast(seq, threshold)):
397
- logger.debug("Local BLAST found accession: %s", accession)
398
- return self.get_by_accession(accession)
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
- # Fall back to network BLAST
401
  logger.debug("Falling back to NCBIWWW.qblast")
402
-
403
- with NCBIWWW.qblast("blastn", "nr", seq, hitlist_size=1, expect=threshold) as result_handle:
404
- return _process_network_blast_result(NCBIXML.read(result_handle), seq, threshold)
 
 
 
 
405
  except (RequestException, IncompleteRead):
406
  raise
407
  except Exception as e:
@@ -414,8 +587,9 @@ class NCBISearch(BaseSearcher):
414
  retry=retry_if_exception_type((RequestException, IncompleteRead)),
415
  reraise=True,
416
  )
417
- async def search(self, query: str, threshold: float = 0.01, **kwargs) -> Optional[Dict]:
418
  """Search NCBI with either a gene ID, accession number, keyword, or DNA sequence."""
 
419
  if not query or not isinstance(query, str):
420
  logger.error("Empty or non-string input.")
421
  return None
@@ -423,19 +597,21 @@ class NCBISearch(BaseSearcher):
423
  query = query.strip()
424
  logger.debug("NCBI search query: %s", query)
425
 
426
- loop = asyncio.get_running_loop()
427
-
428
- # limit concurrent requests (NCBI rate limit: max 3 requests per second)
429
- async with _ncbi_lock:
430
- # Auto-detect query type and execute in thread pool
431
- if query.startswith(">") or re.fullmatch(r"[ATCGN\s]+", query, re.I):
432
- result = await loop.run_in_executor(_get_pool(), self.get_by_fasta, query, threshold)
433
- elif re.fullmatch(r"^\d+$", query):
434
- result = await loop.run_in_executor(_get_pool(), self.get_by_gene_id, query)
435
- elif re.fullmatch(r"[A-Z]{2}_\d+\.?\d*", query, re.I):
436
- result = await loop.run_in_executor(_get_pool(), self.get_by_accession, query)
437
- else:
438
- result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query)
 
 
439
 
440
  if result:
441
  result["_search_query"] = query
 
 
1
  import os
2
  import re
3
  import subprocess
4
  import tempfile
 
 
5
  from http.client import IncompleteRead
6
  from typing import Dict, Optional
7
 
 
19
  from graphgen.utils import logger
20
 
21
 
 
 
 
 
 
 
 
 
 
22
  class NCBISearch(BaseSearcher):
23
  """
24
  NCBI Search client to search DNA/GenBank/Entrez databases.
 
37
  email: str = "email@example.com",
38
  api_key: str = "",
39
  tool: str = "GraphGen",
40
+ blast_num_threads: int = 4,
41
+ threshold: float = 0.01,
42
  ):
43
  """
44
  Initialize the NCBI Search client.
 
49
  email (str): Email address for NCBI API requests.
50
  api_key (str): API key for NCBI API requests, see https://account.ncbi.nlm.nih.gov/settings/.
51
  tool (str): Tool name for NCBI API requests.
52
+ blast_num_threads (int): Number of threads for BLAST search.
53
  """
 
54
  Entrez.timeout = 60 # 60 seconds timeout
55
  Entrez.email = email
56
  Entrez.tool = tool
 
60
  Entrez.sleep_between_tries = 5
61
  self.use_local_blast = use_local_blast
62
  self.local_blast_db = local_blast_db
63
+ self.blast_num_threads = blast_num_threads
64
+ self.threshold = threshold
65
+ if self.use_local_blast:
66
+ # Check for single-file database (.nhr) or multi-file database (.00.nhr)
67
+ db_exists = os.path.isfile(f"{self.local_blast_db}.nhr") or os.path.isfile(
68
+ f"{self.local_blast_db}.00.nhr"
69
+ )
70
+ if not db_exists:
71
+ logger.error(
72
+ "Local BLAST database files not found. Please check the path."
73
+ )
74
+ logger.error(
75
+ "Expected: %s.nhr or %s.00.nhr",
76
+ self.local_blast_db,
77
+ self.local_blast_db,
78
+ )
79
+ self.use_local_blast = False
80
 
81
  @staticmethod
82
  def _nested_get(data: dict, *keys, default=None):
 
88
  return data
89
 
90
  @staticmethod
91
+ def _infer_molecule_type_detail(
92
+ accession: Optional[str], gene_type: Optional[int] = None
93
+ ) -> Optional[str]:
94
  """Infer molecule_type_detail from accession prefix or gene type."""
95
  if accession:
96
+ # Map accession prefixes to molecule types
97
+ prefix_map = {
98
+ ("NM_", "XM_"): "mRNA",
99
+ ("NC_", "NT_"): "genomic DNA",
100
+ ("NR_", "XR_"): "RNA",
101
+ ("NG_",): "genomic region",
102
+ }
103
+ for prefixes, mol_type in prefix_map.items():
104
+ if accession.startswith(prefixes):
105
+ return mol_type
106
  # Fallback: infer from gene type if available
107
  if gene_type is not None:
108
  gene_type_map = {
 
134
  gene_synonyms = []
135
  if isinstance(synonyms_raw, list):
136
  for syn in synonyms_raw:
137
+ gene_synonyms.append(
138
+ syn.get("Gene-ref_syn_E") if isinstance(syn, dict) else str(syn)
139
+ )
140
  elif synonyms_raw:
141
  gene_synonyms.append(str(synonyms_raw))
142
 
143
  # Extract location info
144
  label = locus.get("Gene-commentary_label", "")
145
+ chromosome_match = (
146
+ re.search(r"Chromosome\s+(\S+)", str(label)) if label else None
147
+ )
148
 
149
  seq_interval = self._nested_get(
150
  locus, "Gene-commentary_seqs", 0, "Seq-loc_int", "Seq-interval", default={}
151
  )
152
  genomic_location = (
153
  f"{seq_interval.get('Seq-interval_from')}-{seq_interval.get('Seq-interval_to')}"
154
+ if seq_interval.get("Seq-interval_from")
155
+ and seq_interval.get("Seq-interval_to")
156
  else None
157
  )
158
 
 
166
  None,
167
  )
168
  # Fallback: if no type 3 accession, try any available accession
 
169
  if not representative_accession:
170
  representative_accession = next(
171
  (
 
182
  comment.get("Gene-commentary_comment")
183
  for comment in data.get("Entrezgene_comments", [])
184
  if isinstance(comment, dict)
185
+ and "function"
186
+ in str(comment.get("Gene-commentary_heading", "")).lower()
187
  ),
188
  None,
189
  )
 
207
  "5": "snRNA",
208
  "6": "ncRNA",
209
  "7": "other",
210
+ }.get(
211
+ str(data.get("Entrezgene_type")), f"type_{data.get('Entrezgene_type')}"
212
+ ),
213
  "chromosome": chromosome_match.group(1) if chromosome_match else None,
214
  "genomic_location": genomic_location,
215
  "function": function,
 
224
  "_representative_accession": representative_accession,
225
  }
226
 
227
+ @retry(
228
+ stop=stop_after_attempt(5),
229
+ wait=wait_exponential(multiplier=1, min=4, max=10),
230
+ retry=retry_if_exception_type((RequestException, IncompleteRead)),
231
+ reraise=True,
232
+ )
233
+ def get_by_gene_id(
234
+ self, gene_id: str, preferred_accession: Optional[str] = None
235
+ ) -> Optional[dict]:
236
  """Get gene information by Gene ID."""
237
+
238
  def _extract_metadata_from_genbank(result: dict, accession: str):
239
  """Extract metadata from GenBank format (title, features, organism, etc.)."""
240
+ with Entrez.efetch(
241
+ db="nuccore", id=accession, rettype="gb", retmode="text"
242
+ ) as handle:
243
  record = SeqIO.read(handle, "genbank")
244
 
245
  result["title"] = record.description
246
  result["molecule_type_detail"] = (
247
+ self._infer_molecule_type_detail(accession) or "N/A"
 
 
 
248
  )
249
 
250
  for feature in record.features:
251
  if feature.type == "source":
252
+ if "chromosome" in feature.qualifiers:
253
+ result["chromosome"] = feature.qualifiers["chromosome"][0]
254
 
255
  if feature.location:
256
  start = int(feature.location.start) + 1
 
259
 
260
  break
261
 
262
+ if not result.get("organism") and "organism" in record.annotations:
263
+ result["organism"] = record.annotations["organism"]
264
 
265
  return result
266
 
267
  def _extract_sequence_from_fasta(result: dict, accession: str):
268
  """Extract sequence from FASTA format (more reliable than GenBank for CON-type records)."""
269
  try:
270
+ with Entrez.efetch(
271
+ db="nuccore", id=accession, rettype="fasta", retmode="text"
272
+ ) as fasta_handle:
273
  fasta_record = SeqIO.read(fasta_handle, "fasta")
274
  result["sequence"] = str(fasta_record.seq)
275
  result["sequence_length"] = len(fasta_record.seq)
276
  except Exception as fasta_exc:
277
  logger.warning(
278
  "Failed to extract sequence from accession %s using FASTA format: %s",
279
+ accession,
280
+ fasta_exc,
281
  )
282
  result["sequence"] = None
283
  result["sequence_length"] = None
284
  return result
285
 
286
+ def _extract_sequence(result: dict, accession: str):
287
+ """
288
+ Extract sequence using the appropriate method based on configuration.
289
+ If use_local_blast=True, use local database. Otherwise, use NCBI API.
290
+ Always fetches sequence (no option to skip).
291
+ """
292
+ # If using local BLAST, use local database
293
+ if self.use_local_blast:
294
+ sequence = self._extract_sequence_from_local_db(accession)
295
+
296
+ if sequence:
297
+ result["sequence"] = sequence
298
+ result["sequence_length"] = len(sequence)
299
+ else:
300
+ # Failed to extract from local DB, set to None (no fallback to API)
301
+ result["sequence"] = None
302
+ result["sequence_length"] = None
303
+ logger.warning(
304
+ "Failed to extract sequence from local DB for accession %s. "
305
+ "Not falling back to NCBI API as use_local_blast=True.",
306
+ accession,
307
+ )
308
+ else:
309
+ # Use NCBI API to fetch sequence
310
+ result = _extract_sequence_from_fasta(result, accession)
311
+
312
+ return result
313
+
314
  try:
315
  with Entrez.efetch(db="gene", id=gene_id, retmode="xml") as handle:
316
  gene_record = Entrez.read(handle)
 
 
317
 
318
+ if not gene_record:
319
+ return None
320
+
321
+ result = self._gene_record_to_dict(gene_record, gene_id)
322
+
323
+ if accession := (
324
+ preferred_accession or result.get("_representative_accession")
325
+ ):
326
+ result = _extract_metadata_from_genbank(result, accession)
327
+ # Extract sequence using appropriate method
328
+ result = _extract_sequence(result, accession)
329
 
330
+ result.pop("_representative_accession", None)
331
+ return result
332
  except (RequestException, IncompleteRead):
333
  raise
334
  except Exception as exc:
335
  logger.error("Gene ID %s not found: %s", gene_id, exc)
336
  return None
337
 
338
+ @retry(
339
+ stop=stop_after_attempt(5),
340
+ wait=wait_exponential(multiplier=1, min=4, max=10),
341
+ retry=retry_if_exception_type((RequestException, IncompleteRead)),
342
+ reraise=True,
343
+ )
344
  def get_by_accession(self, accession: str) -> Optional[dict]:
345
  """Get sequence information by accession number."""
346
+
347
  def _extract_gene_id(link_handle):
348
  """Extract GeneID from elink results."""
349
  links = Entrez.read(link_handle)
 
367
  return None
368
 
369
  result = self.get_by_gene_id(gene_id, preferred_accession=accession)
370
+
371
  if result:
372
  result["id"] = accession
373
  result["url"] = f"https://www.ncbi.nlm.nih.gov/nuccore/{accession}"
374
+
375
  return result
376
  except (RequestException, IncompleteRead):
377
  raise
 
379
  logger.error("Accession %s not found: %s", accession, exc)
380
  return None
381
 
382
+ @retry(
383
+ stop=stop_after_attempt(5),
384
+ wait=wait_exponential(multiplier=1, min=4, max=10),
385
+ retry=retry_if_exception_type((RequestException, IncompleteRead)),
386
+ reraise=True,
387
+ )
388
  def get_best_hit(self, keyword: str) -> Optional[dict]:
389
  """Search NCBI Gene database with a keyword and return the best hit."""
390
  if not keyword.strip():
 
392
 
393
  try:
394
  for search_term in [f"{keyword}[Gene] OR {keyword}[All Fields]", keyword]:
395
+ with Entrez.esearch(
396
+ db="gene", term=search_term, retmax=1, sort="relevance"
397
+ ) as search_handle:
398
  search_results = Entrez.read(search_handle)
399
+
400
+ if len(gene_id := search_results.get("IdList", [])) > 0:
401
+ result = self.get_by_gene_id(gene_id[0])
402
+ return result
403
  except (RequestException, IncompleteRead):
404
  raise
405
  except Exception as e:
406
  logger.error("Keyword %s not found: %s", keyword, e)
407
  return None
408
 
409
+ def _extract_sequence_from_local_db(self, accession: str) -> Optional[str]:
410
+ """Extract sequence from local BLAST database using blastdbcmd."""
411
+ try:
412
+ cmd = [
413
+ "blastdbcmd",
414
+ "-db",
415
+ self.local_blast_db,
416
+ "-entry",
417
+ accession,
418
+ "-outfmt",
419
+ "%s", # Only sequence, no header
420
+ ]
421
+ sequence = subprocess.check_output(
422
+ cmd,
423
+ text=True,
424
+ timeout=10, # 10 second timeout for local extraction
425
+ stderr=subprocess.DEVNULL,
426
+ ).strip()
427
+ return sequence if sequence else None
428
+ except subprocess.TimeoutExpired:
429
+ logger.warning(
430
+ "Timeout extracting sequence from local DB for accession %s", accession
431
+ )
432
+ return None
433
+ except Exception as exc:
434
+ logger.warning(
435
+ "Failed to extract sequence from local DB for accession %s: %s",
436
+ accession,
437
+ exc,
438
+ )
439
+ return None
440
+
441
  def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
442
+ """
443
+ Perform local BLAST search using local BLAST database.
444
+ Optimized with multi-threading and faster output format.
445
+ """
446
  try:
447
+ with tempfile.NamedTemporaryFile(
448
+ mode="w+", suffix=".fa", delete=False
449
+ ) as tmp:
450
  tmp.write(f">query\n{seq}\n")
451
  tmp_name = tmp.name
452
 
453
+ # Optimized BLAST command with:
454
+ # - num_threads: Use multiple threads for faster search
455
+ # - outfmt 6 sacc: Only return accession (minimal output)
456
+ # - max_target_seqs 1: Only need the best hit
457
+ # - evalue: Threshold for significance
458
  cmd = [
459
+ "blastn",
460
+ "-db",
461
+ self.local_blast_db,
462
+ "-query",
463
+ tmp_name,
464
+ "-evalue",
465
+ str(threshold),
466
+ "-max_target_seqs",
467
+ "1",
468
+ "-num_threads",
469
+ str(self.blast_num_threads),
470
+ "-outfmt",
471
+ "6 sacc", # Only accession, tab-separated
472
  ]
473
+ logger.debug(
474
+ "Running local blastn (threads=%d): %s",
475
+ self.blast_num_threads,
476
+ " ".join(cmd),
477
+ )
478
+
479
+ # Run BLAST with timeout to avoid hanging
480
+ try:
481
+ out = subprocess.check_output(
482
+ cmd,
483
+ text=True,
484
+ timeout=300, # 5 minute timeout for BLAST search
485
+ stderr=subprocess.DEVNULL, # Suppress BLAST warnings to reduce I/O
486
+ ).strip()
487
+ except subprocess.TimeoutExpired:
488
+ logger.warning("BLAST search timed out after 5 minutes for sequence")
489
+ os.remove(tmp_name)
490
+ return None
491
+
492
  os.remove(tmp_name)
493
  return out.split("\n", maxsplit=1)[0] if out else None
494
  except Exception as exc:
495
  logger.error("Local blastn failed: %s", exc)
496
+ # Clean up temp file if it still exists
497
+ try:
498
+ if "tmp_name" in locals():
499
+ os.remove(tmp_name)
500
+ except Exception:
501
+ pass
502
  return None
503
 
504
  def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
 
512
  seq = sequence.strip().replace(" ", "").replace("\n", "")
513
  return seq if re.fullmatch(r"[ATCGN]+", seq, re.I) else None
514
 
515
+ def _process_network_blast_result(
516
+ blast_record, seq: str, threshold: float
517
+ ) -> Optional[dict]:
518
  """Process network BLAST result and return dictionary or None."""
519
  if not blast_record.alignments:
520
  logger.info("No BLAST hits found for the given sequence.")
 
538
  "title": best_alignment.title,
539
  "sequence_length": len(seq),
540
  "e_value": best_hsp.expect,
541
+ "identity": best_hsp.identities / best_hsp.align_length
542
+ if best_hsp.align_length > 0
543
+ else 0,
544
  "url": f"https://www.ncbi.nlm.nih.gov/nuccore/{hit_id}",
545
  }
546
 
 
550
  return None
551
 
552
  # Try local BLAST first if enabled
553
+ if self.use_local_blast:
554
+ accession = self._local_blast(seq, threshold)
555
+
556
+ if accession:
557
+ logger.debug("Local BLAST found accession: %s", accession)
558
+ # When using local BLAST, skip sequence fetching by default (faster, fewer API calls)
559
+ # Sequence is already known from the query, so we only need metadata
560
+ result = self.get_by_accession(accession)
561
+ return result
562
+
563
+ logger.info(
564
+ "Local BLAST found no match for sequence. "
565
+ "API fallback disabled when using local database."
566
+ )
567
+ return None
568
 
569
+ # Fall back to network BLAST only if local BLAST is not enabled
570
  logger.debug("Falling back to NCBIWWW.qblast")
571
+ with NCBIWWW.qblast(
572
+ "blastn", "nr", seq, hitlist_size=1, expect=threshold
573
+ ) as result_handle:
574
+ result = _process_network_blast_result(
575
+ NCBIXML.read(result_handle), seq, threshold
576
+ )
577
+ return result
578
  except (RequestException, IncompleteRead):
579
  raise
580
  except Exception as e:
 
587
  retry=retry_if_exception_type((RequestException, IncompleteRead)),
588
  reraise=True,
589
  )
590
+ def search(self, query: str, threshold: float = None, **kwargs) -> Optional[Dict]:
591
  """Search NCBI with either a gene ID, accession number, keyword, or DNA sequence."""
592
+ threshold = threshold or self.threshold
593
  if not query or not isinstance(query, str):
594
  logger.error("Empty or non-string input.")
595
  return None
 
597
  query = query.strip()
598
  logger.debug("NCBI search query: %s", query)
599
 
600
+ # Auto-detect query type and execute
601
+ # All methods call NCBI API (rate limit: max 3 requests per second)
602
+ # Even if get_by_fasta uses local BLAST, it still calls get_by_accession which needs API
603
+ if query.startswith(">") or re.fullmatch(r"[ATCGN\s]+", query, re.I):
604
+ # FASTA sequence
605
+ result = self.get_by_fasta(query, threshold)
606
+ elif re.fullmatch(r"^\d+$", query):
607
+ # Gene ID
608
+ result = self.get_by_gene_id(query)
609
+ elif re.fullmatch(r"[A-Z]{2}_\d+\.?\d*", query, re.I):
610
+ # Accession
611
+ result = self.get_by_accession(query)
612
+ else:
613
+ # Keyword
614
+ result = self.get_best_hit(query)
615
 
616
  if result:
617
  result["_search_query"] = query
graphgen/models/searcher/db/rnacentral_searcher.py CHANGED
@@ -1,15 +1,11 @@
1
- import asyncio
2
  import os
3
  import re
4
  import subprocess
5
- from concurrent.futures import ThreadPoolExecutor
6
- from functools import lru_cache
7
  import tempfile
8
- from typing import Dict, Optional, List, Any, Set
9
 
10
- import hashlib
11
  import requests
12
- import aiohttp
13
  from tenacity import (
14
  retry,
15
  retry_if_exception_type,
@@ -21,10 +17,6 @@ from graphgen.bases import BaseSearcher
21
  from graphgen.utils import logger
22
 
23
 
24
- @lru_cache(maxsize=None)
25
- def _get_pool():
26
- return ThreadPoolExecutor(max_workers=10)
27
-
28
  class RNACentralSearch(BaseSearcher):
29
  """
30
  RNAcentral Search client to search RNA databases.
@@ -35,12 +27,22 @@ class RNACentralSearch(BaseSearcher):
35
  API Documentation: https://rnacentral.org/api/v1
36
  """
37
 
38
- def __init__(self, use_local_blast: bool = False, local_blast_db: str = "rna_db"):
39
- super().__init__()
 
 
 
 
 
 
40
  self.base_url = "https://rnacentral.org/api/v1"
41
  self.headers = {"Accept": "application/json"}
42
  self.use_local_blast = use_local_blast
43
  self.local_blast_db = local_blast_db
 
 
 
 
44
  if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.nhr"):
45
  logger.error("Local BLAST database files not found. Please check the path.")
46
  self.use_local_blast = False
@@ -49,7 +51,7 @@ class RNACentralSearch(BaseSearcher):
49
  def _rna_data_to_dict(
50
  rna_id: str,
51
  rna_data: Dict[str, Any],
52
- xrefs_data: Optional[List[Dict[str, Any]]] = None
53
  ) -> Dict[str, Any]:
54
  organisms, gene_names, so_terms = set(), set(), set()
55
  modifications: List[Any] = []
@@ -58,7 +60,8 @@ class RNACentralSearch(BaseSearcher):
58
  acc = xref.get("accession", {})
59
  if s := acc.get("species"):
60
  organisms.add(s)
61
- if g := acc.get("gene", "").strip():
 
62
  gene_names.add(g)
63
  if m := xref.get("modifications"):
64
  modifications.extend(m)
@@ -137,7 +140,9 @@ class RNACentralSearch(BaseSearcher):
137
  # Normalize sequence
138
  normalized_seq = sequence.replace("U", "T").replace("u", "t").upper()
139
  if not re.fullmatch(r"[ATCGN]+", normalized_seq):
140
- raise ValueError(f"Invalid sequence characters after normalization: {normalized_seq[:50]}...")
 
 
141
 
142
  return hashlib.md5(normalized_seq.encode("ascii")).hexdigest()
143
 
@@ -151,12 +156,21 @@ class RNACentralSearch(BaseSearcher):
151
  url = f"{self.base_url}/rna/{rna_id}"
152
  url += "?flat=true"
153
 
154
- resp = requests.get(url, headers=self.headers, timeout=30)
155
  resp.raise_for_status()
156
 
157
  rna_data = resp.json()
158
  xrefs_data = rna_data.get("xrefs", [])
159
- return self._rna_data_to_dict(rna_id, rna_data, xrefs_data)
 
 
 
 
 
 
 
 
 
160
  except requests.RequestException as e:
161
  logger.error("Network error getting RNA ID %s: %s", rna_id, e)
162
  return None
@@ -164,6 +178,12 @@ class RNACentralSearch(BaseSearcher):
164
  logger.error("Unexpected error getting RNA ID %s: %s", rna_id, e)
165
  return None
166
 
 
 
 
 
 
 
167
  def get_best_hit(self, keyword: str) -> Optional[dict]:
168
  """
169
  Search RNAcentral with a keyword and return the best hit.
@@ -178,7 +198,9 @@ class RNACentralSearch(BaseSearcher):
178
  try:
179
  url = f"{self.base_url}/rna"
180
  params = {"search": keyword, "format": "json"}
181
- resp = requests.get(url, params=params, headers=self.headers, timeout=30)
 
 
182
  resp.raise_for_status()
183
 
184
  data = resp.json()
@@ -206,76 +228,146 @@ class RNACentralSearch(BaseSearcher):
206
  return None
207
 
208
  def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
209
- """Perform local BLAST search using local BLAST database."""
 
 
 
210
  try:
211
- with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp:
 
 
 
212
  tmp.write(f">query\n{seq}\n")
213
  tmp_name = tmp.name
214
 
 
 
 
 
 
215
  cmd = [
216
- "blastn", "-db", self.local_blast_db, "-query", tmp_name,
217
- "-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc"
 
 
 
 
 
 
 
 
 
 
 
218
  ]
219
- logger.debug("Running local blastn for RNA: %s", " ".join(cmd))
220
- out = subprocess.check_output(cmd, text=True).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  os.remove(tmp_name)
222
  return out.split("\n", maxsplit=1)[0] if out else None
223
  except Exception as exc:
224
  logger.error("Local blastn failed: %s", exc)
 
 
 
 
 
 
225
  return None
226
 
227
- def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
228
- """
229
- Search RNAcentral with an RNA sequence.
230
- Tries local BLAST first if enabled, falls back to RNAcentral API.
231
- Unified approach: Find RNA ID from sequence search, then call get_by_rna_id() for complete information.
232
- :param sequence: RNA sequence (FASTA format or raw sequence).
233
- :param threshold: E-value threshold for BLAST search.
234
- :return: A dictionary containing complete RNA information or None if not found.
235
- """
236
- def _extract_sequence(sequence: str) -> Optional[str]:
237
- """Extract and normalize RNA sequence from input."""
238
- if sequence.startswith(">"):
239
- seq_lines = sequence.strip().split("\n")
240
- seq = "".join(seq_lines[1:])
241
- else:
242
- seq = sequence.strip().replace(" ", "").replace("\n", "")
243
- return seq if seq and re.fullmatch(r"[AUCGN\s]+", seq, re.I) else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  try:
246
- seq = _extract_sequence(sequence)
247
  if not seq:
248
  logger.error("Empty or invalid RNA sequence provided.")
249
  return None
250
 
251
- # Try local BLAST first if enabled
252
  if self.use_local_blast:
253
- accession = self._local_blast(seq, threshold)
254
- if accession:
255
- logger.debug("Local BLAST found accession: %s", accession)
256
- return self.get_by_rna_id(accession)
257
-
258
- # Fall back to RNAcentral API if local BLAST didn't find result
259
- logger.debug("Falling back to RNAcentral API.")
260
-
261
- md5_hash = self._calculate_md5(seq)
262
- search_url = f"{self.base_url}/rna"
263
- params = {"md5": md5_hash, "format": "json"}
264
-
265
- resp = requests.get(search_url, params=params, headers=self.headers, timeout=60)
266
- resp.raise_for_status()
267
-
268
- search_results = resp.json()
269
- results = search_results.get("results", [])
270
-
271
- if not results:
272
- logger.info("No exact match found in RNAcentral for sequence")
273
- return None
274
- rna_id = results[0].get("rnacentral_id")
275
- if not rna_id:
276
- logger.error("No RNAcentral ID found in search results.")
277
- return None
278
- return self.get_by_rna_id(rna_id)
279
  except Exception as e:
280
  logger.error("Sequence search failed: %s", e)
281
  return None
@@ -283,11 +375,12 @@ class RNACentralSearch(BaseSearcher):
283
  @retry(
284
  stop=stop_after_attempt(3),
285
  wait=wait_exponential(multiplier=1, min=2, max=10),
286
- retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
287
  reraise=True,
288
  )
289
- async def search(self, query: str, threshold: float = 0.1, **kwargs) -> Optional[Dict]:
290
  """Search RNAcentral with either an RNAcentral ID, keyword, or RNA sequence."""
 
291
  if not query or not isinstance(query, str):
292
  logger.error("Empty or non-string input.")
293
  return None
@@ -295,19 +388,20 @@ class RNACentralSearch(BaseSearcher):
295
  query = query.strip()
296
  logger.debug("RNAcentral search query: %s", query)
297
 
298
- loop = asyncio.get_running_loop()
299
-
300
- # check if RNA sequence (AUCG characters, contains U)
301
- if query.startswith(">") or (
302
- re.fullmatch(r"[AUCGN\s]+", query, re.I) and "U" in query.upper()
303
- ):
304
- result = await loop.run_in_executor(_get_pool(), self.get_by_fasta, query, threshold)
 
305
  # check if RNAcentral ID (typically starts with URS)
306
  elif re.fullmatch(r"URS\d+", query, re.I):
307
- result = await loop.run_in_executor(_get_pool(), self.get_by_rna_id, query)
308
  else:
309
  # otherwise treat as keyword
310
- result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query)
311
 
312
  if result:
313
  result["_search_query"] = query
 
1
+ import hashlib
2
  import os
3
  import re
4
  import subprocess
 
 
5
  import tempfile
6
+ from typing import Any, Dict, List, Optional, Set
7
 
 
8
  import requests
 
9
  from tenacity import (
10
  retry,
11
  retry_if_exception_type,
 
17
  from graphgen.utils import logger
18
 
19
 
 
 
 
 
20
  class RNACentralSearch(BaseSearcher):
21
  """
22
  RNAcentral Search client to search RNA databases.
 
27
  API Documentation: https://rnacentral.org/api/v1
28
  """
29
 
30
+ def __init__(
31
+ self,
32
+ use_local_blast: bool = False,
33
+ local_blast_db: str = "rna_db",
34
+ api_timeout: int = 30,
35
+ blast_num_threads: int = 4,
36
+ threshold: float = 0.01,
37
+ ):
38
  self.base_url = "https://rnacentral.org/api/v1"
39
  self.headers = {"Accept": "application/json"}
40
  self.use_local_blast = use_local_blast
41
  self.local_blast_db = local_blast_db
42
+ self.api_timeout = api_timeout
43
+ self.blast_num_threads = blast_num_threads # Number of threads for BLAST search
44
+ self.threshold = threshold # E-value threshold for BLAST search
45
+
46
  if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.nhr"):
47
  logger.error("Local BLAST database files not found. Please check the path.")
48
  self.use_local_blast = False
 
51
  def _rna_data_to_dict(
52
  rna_id: str,
53
  rna_data: Dict[str, Any],
54
+ xrefs_data: Optional[List[Dict[str, Any]]] = None,
55
  ) -> Dict[str, Any]:
56
  organisms, gene_names, so_terms = set(), set(), set()
57
  modifications: List[Any] = []
 
60
  acc = xref.get("accession", {})
61
  if s := acc.get("species"):
62
  organisms.add(s)
63
+ gene_value = acc.get("gene")
64
+ if isinstance(gene_value, str) and (g := gene_value.strip()):
65
  gene_names.add(g)
66
  if m := xref.get("modifications"):
67
  modifications.extend(m)
 
140
  # Normalize sequence
141
  normalized_seq = sequence.replace("U", "T").replace("u", "t").upper()
142
  if not re.fullmatch(r"[ATCGN]+", normalized_seq):
143
+ raise ValueError(
144
+ f"Invalid sequence characters after normalization: {normalized_seq[:50]}..."
145
+ )
146
 
147
  return hashlib.md5(normalized_seq.encode("ascii")).hexdigest()
148
 
 
156
  url = f"{self.base_url}/rna/{rna_id}"
157
  url += "?flat=true"
158
 
159
+ resp = requests.get(url, headers=self.headers, timeout=self.api_timeout)
160
  resp.raise_for_status()
161
 
162
  rna_data = resp.json()
163
  xrefs_data = rna_data.get("xrefs", [])
164
+ result = self._rna_data_to_dict(rna_id, rna_data, xrefs_data)
165
+ return result
166
+ except requests.Timeout as e:
167
+ logger.warning(
168
+ "Timeout getting RNA ID %s (timeout=%ds): %s",
169
+ rna_id,
170
+ self.api_timeout,
171
+ e,
172
+ )
173
+ return None
174
  except requests.RequestException as e:
175
  logger.error("Network error getting RNA ID %s: %s", rna_id, e)
176
  return None
 
178
  logger.error("Unexpected error getting RNA ID %s: %s", rna_id, e)
179
  return None
180
 
181
+ @retry(
182
+ stop=stop_after_attempt(3),
183
+ wait=wait_exponential(multiplier=1, min=2, max=10),
184
+ retry=retry_if_exception_type((requests.Timeout, requests.RequestException)),
185
+ reraise=False,
186
+ )
187
  def get_best_hit(self, keyword: str) -> Optional[dict]:
188
  """
189
  Search RNAcentral with a keyword and return the best hit.
 
198
  try:
199
  url = f"{self.base_url}/rna"
200
  params = {"search": keyword, "format": "json"}
201
+ resp = requests.get(
202
+ url, params=params, headers=self.headers, timeout=self.api_timeout
203
+ )
204
  resp.raise_for_status()
205
 
206
  data = resp.json()
 
228
  return None
229
 
230
  def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
231
+ """
232
+ Perform local BLAST search using local BLAST database.
233
+ Optimized with multi-threading and faster output format.
234
+ """
235
  try:
236
+ # Use temporary file for query sequence
237
+ with tempfile.NamedTemporaryFile(
238
+ mode="w+", suffix=".fa", delete=False
239
+ ) as tmp:
240
  tmp.write(f">query\n{seq}\n")
241
  tmp_name = tmp.name
242
 
243
+ # Optimized BLAST command with:
244
+ # - num_threads: Use multiple threads for faster search
245
+ # - outfmt 6 sacc: Only return accession (minimal output)
246
+ # - max_target_seqs 1: Only need the best hit
247
+ # - evalue: Threshold for significance
248
  cmd = [
249
+ "blastn",
250
+ "-db",
251
+ self.local_blast_db,
252
+ "-query",
253
+ tmp_name,
254
+ "-evalue",
255
+ str(threshold),
256
+ "-max_target_seqs",
257
+ "1",
258
+ "-num_threads",
259
+ str(self.blast_num_threads),
260
+ "-outfmt",
261
+ "6 sacc", # Only accession, tab-separated
262
  ]
263
+ logger.debug(
264
+ "Running local blastn for RNA (threads=%d): %s",
265
+ self.blast_num_threads,
266
+ " ".join(cmd),
267
+ )
268
+
269
+ # Run BLAST with timeout to avoid hanging
270
+ try:
271
+ out = subprocess.check_output(
272
+ cmd,
273
+ text=True,
274
+ timeout=300, # 5 minute timeout for BLAST search
275
+ stderr=subprocess.DEVNULL, # Suppress BLAST warnings to reduce I/O
276
+ ).strip()
277
+ except subprocess.TimeoutExpired:
278
+ logger.warning("BLAST search timed out after 5 minutes for sequence")
279
+ os.remove(tmp_name)
280
+ return None
281
+
282
  os.remove(tmp_name)
283
  return out.split("\n", maxsplit=1)[0] if out else None
284
  except Exception as exc:
285
  logger.error("Local blastn failed: %s", exc)
286
+ # Clean up temp file if it still exists
287
+ try:
288
+ if "tmp_name" in locals():
289
+ os.remove(tmp_name)
290
+ except Exception:
291
+ pass
292
  return None
293
 
294
+ @staticmethod
295
+ def _extract_rna_sequence(sequence: str) -> Optional[str]:
296
+ """Extract and normalize RNA sequence from input."""
297
+ if sequence.startswith(">"):
298
+ seq_lines = sequence.strip().split("\n")
299
+ seq = "".join(seq_lines[1:])
300
+ else:
301
+ seq = sequence.strip().replace(" ", "").replace("\n", "")
302
+ # Accept both U (original RNA) and T
303
+ return seq if seq and re.fullmatch(r"[AUCGTN\s]+", seq, re.I) else None
304
+
305
+ def _search_with_local_blast(self, seq: str, threshold: float) -> Optional[dict]:
306
+ """Search using local BLAST database."""
307
+ accession = self._local_blast(seq, threshold)
308
+ if not accession:
309
+ logger.info(
310
+ "Local BLAST found no match for sequence. "
311
+ "API fallback disabled when using local database."
312
+ )
313
+ return None
314
+
315
+ logger.debug("Local BLAST found accession: %s", accession)
316
+ detailed = self.get_by_rna_id(accession)
317
+ if detailed:
318
+ return detailed
319
+ logger.info(
320
+ "Local BLAST found accession %s but could not retrieve metadata from API.",
321
+ accession,
322
+ )
323
+ return None
324
+
325
+ def _search_with_api(self, seq: str) -> Optional[dict]:
326
+ """Search using RNAcentral API with MD5 hash."""
327
+ logger.debug("Falling back to RNAcentral API.")
328
+ md5_hash = self._calculate_md5(seq)
329
+ search_url = f"{self.base_url}/rna"
330
+ params = {"md5": md5_hash, "format": "json"}
331
+
332
+ resp = requests.get(
333
+ search_url, params=params, headers=self.headers, timeout=60
334
+ )
335
+ resp.raise_for_status()
336
+
337
+ search_results = resp.json()
338
+ results = search_results.get("results", [])
339
+
340
+ if not results:
341
+ logger.info("No exact match found in RNAcentral for sequence")
342
+ return None
343
 
344
+ rna_id = results[0].get("rnacentral_id")
345
+ if not rna_id:
346
+ logger.error("No RNAcentral ID found in search results.")
347
+ return None
348
+
349
+ detailed = self.get_by_rna_id(rna_id)
350
+ if detailed:
351
+ return detailed
352
+ # Fallback: use search result data if get_by_rna_id returns None
353
+ logger.debug(
354
+ "Using search result data for %s (get_by_rna_id returned None)", rna_id
355
+ )
356
+ return self._rna_data_to_dict(rna_id, results[0])
357
+
358
+ def get_by_fasta(
359
+ self, sequence: str, threshold: float = 0.01
360
+ ) -> Optional[dict]:
361
+ """Search RNAcentral with an RNA sequence."""
362
  try:
363
+ seq = self._extract_rna_sequence(sequence)
364
  if not seq:
365
  logger.error("Empty or invalid RNA sequence provided.")
366
  return None
367
 
 
368
  if self.use_local_blast:
369
+ return self._search_with_local_blast(seq, threshold)
370
+ return self._search_with_api(seq)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  except Exception as e:
372
  logger.error("Sequence search failed: %s", e)
373
  return None
 
375
  @retry(
376
  stop=stop_after_attempt(3),
377
  wait=wait_exponential(multiplier=1, min=2, max=10),
378
+ retry=retry_if_exception_type((requests.Timeout, requests.RequestException)),
379
  reraise=True,
380
  )
381
+ def search(self, query: str, threshold: float = None, **kwargs) -> Optional[Dict]:
382
  """Search RNAcentral with either an RNAcentral ID, keyword, or RNA sequence."""
383
+ threshold = threshold or self.threshold
384
  if not query or not isinstance(query, str):
385
  logger.error("Empty or non-string input.")
386
  return None
 
388
  query = query.strip()
389
  logger.debug("RNAcentral search query: %s", query)
390
 
391
+ # check if RNA sequence (AUCG or ATCG characters, contains U or T)
392
+ # Note: Sequences with T are also RNA sequences
393
+ is_rna_sequence = query.startswith(">") or (
394
+ re.fullmatch(r"[AUCGTN\s]+", query, re.I)
395
+ and ("U" in query.upper() or "T" in query.upper())
396
+ )
397
+ if is_rna_sequence:
398
+ result = self.get_by_fasta(query, threshold)
399
  # check if RNAcentral ID (typically starts with URS)
400
  elif re.fullmatch(r"URS\d+", query, re.I):
401
+ result = self.get_by_rna_id(query)
402
  else:
403
  # otherwise treat as keyword
404
+ result = self.get_best_hit(query)
405
 
406
  if result:
407
  result["_search_query"] = query
graphgen/models/searcher/db/uniprot_searcher.py CHANGED
@@ -1,10 +1,7 @@
1
- import asyncio
2
  import os
3
  import re
4
  import subprocess
5
  import tempfile
6
- from concurrent.futures import ThreadPoolExecutor
7
- from functools import lru_cache
8
  from io import StringIO
9
  from typing import Dict, Optional
10
 
@@ -22,15 +19,6 @@ from graphgen.bases import BaseSearcher
22
  from graphgen.utils import logger
23
 
24
 
25
- @lru_cache(maxsize=None)
26
- def _get_pool():
27
- return ThreadPoolExecutor(max_workers=10)
28
-
29
-
30
- # ensure only one BLAST searcher at a time
31
- _blast_lock = asyncio.Lock()
32
-
33
-
34
  class UniProtSearch(BaseSearcher):
35
  """
36
  UniProt Search client to searcher with UniProt.
@@ -39,10 +27,18 @@ class UniProtSearch(BaseSearcher):
39
  3) Search with FASTA sequence (BLAST searcher). Note that NCBIWWW does not support async.
40
  """
41
 
42
- def __init__(self, use_local_blast: bool = False, local_blast_db: str = "sp_db"):
43
- super().__init__()
 
 
 
 
 
44
  self.use_local_blast = use_local_blast
45
  self.local_blast_db = local_blast_db
 
 
 
46
  if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.phr"):
47
  logger.error("Local BLAST database files not found. Please check the path.")
48
  self.use_local_blast = False
@@ -61,7 +57,7 @@ class UniProtSearch(BaseSearcher):
61
 
62
  @staticmethod
63
  def _swissprot_to_dict(record: SwissProt.Record) -> dict:
64
- """error
65
  Convert a SwissProt.Record to a dictionary.
66
  """
67
  functions = []
@@ -104,75 +100,88 @@ class UniProtSearch(BaseSearcher):
104
  logger.error("Keyword %s not found: %s", keyword, e)
105
  return None
106
 
107
- def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
 
108
  """
109
- Search UniProt with a FASTA sequence and return the best hit.
110
  :param fasta_sequence: The FASTA sequence.
111
- :param threshold: E-value threshold for BLAST searcher.
112
- :return: A dictionary containing the best hit information or None if not found.
113
  """
114
  try:
115
  if fasta_sequence.startswith(">"):
116
  seq = str(list(SeqIO.parse(StringIO(fasta_sequence), "fasta"))[0].seq)
117
  else:
118
  seq = fasta_sequence.strip()
 
119
  except Exception as e: # pylint: disable=broad-except
120
  logger.error("Invalid FASTA sequence: %s", e)
121
  return None
122
 
123
- if not seq:
124
- logger.error("Empty FASTA sequence provided.")
 
 
 
 
 
 
125
  return None
 
 
126
 
127
- accession = None
128
- if self.use_local_blast:
129
- accession = self._local_blast(seq, threshold)
130
- if accession:
131
- logger.debug("Local BLAST found accession: %s", accession)
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- if not accession:
134
- logger.debug("Falling back to NCBIWWW.qblast.")
 
135
 
136
- # UniProtKB/Swiss-Prot BLAST API
137
- try:
138
- logger.debug(
139
- "Performing BLAST searcher for the given sequence: %s", seq
140
- )
141
- result_handle = NCBIWWW.qblast(
142
- program="blastp",
143
- database="swissprot",
144
- sequence=seq,
145
- hitlist_size=1,
146
- expect=threshold,
147
- )
148
- blast_record = NCBIXML.read(result_handle)
149
- except RequestException:
150
- raise
151
- except Exception as e: # pylint: disable=broad-except
152
- logger.error("BLAST searcher failed: %s", e)
153
- return None
154
 
155
- if not blast_record.alignments:
156
- logger.info("No BLAST hits found for the given sequence.")
157
- return None
 
158
 
159
- best_alignment = blast_record.alignments[0]
160
- best_hsp = best_alignment.hsps[0]
161
- if best_hsp.expect > threshold:
162
- logger.info("No BLAST hits below the threshold E-value.")
163
- return None
164
- hit_id = best_alignment.hit_id
 
 
165
 
166
- # like sp|P01308.1|INS_HUMAN
167
- accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
168
- return self.get_by_accession(accession)
 
 
169
 
170
  def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
171
  """
172
  Perform local BLAST search using local BLAST database.
173
- :param seq: The protein sequence.
174
- :param threshold: E-value threshold for BLAST searcher.
175
- :return: The accession number of the best hit or None if not found.
176
  """
177
  try:
178
  with tempfile.NamedTemporaryFile(
@@ -181,6 +190,11 @@ class UniProtSearch(BaseSearcher):
181
  tmp.write(f">query\n{seq}\n")
182
  tmp_name = tmp.name
183
 
 
 
 
 
 
184
  cmd = [
185
  "blastp",
186
  "-db",
@@ -191,11 +205,30 @@ class UniProtSearch(BaseSearcher):
191
  str(threshold),
192
  "-max_target_seqs",
193
  "1",
 
 
194
  "-outfmt",
195
- "6 sacc", # only return accession
196
  ]
197
- logger.debug("Running local blastp: %s", " ".join(cmd))
198
- out = subprocess.check_output(cmd, text=True).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  os.remove(tmp_name)
200
  if out:
201
  return out.split("\n", maxsplit=1)[0]
@@ -210,16 +243,14 @@ class UniProtSearch(BaseSearcher):
210
  retry=retry_if_exception_type(RequestException),
211
  reraise=True,
212
  )
213
- async def search(
214
- self, query: str, threshold: float = 0.7, **kwargs
215
- ) -> Optional[Dict]:
216
  """
217
  Search UniProt with either an accession number, keyword, or FASTA sequence.
218
  :param query: The searcher query (accession number, keyword, or FASTA sequence).
219
  :param threshold: E-value threshold for BLAST searcher.
220
  :return: A dictionary containing the best hit information or None if not found.
221
  """
222
-
223
  # auto detect query type
224
  if not query or not isinstance(query, str):
225
  logger.error("Empty or non-string input.")
@@ -228,26 +259,21 @@ class UniProtSearch(BaseSearcher):
228
 
229
  logger.debug("UniProt searcher query: %s", query)
230
 
231
- loop = asyncio.get_running_loop()
232
-
233
  # check if fasta sequence
234
  if query.startswith(">") or re.fullmatch(
235
  r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
236
  ):
237
- async with _blast_lock:
238
- result = await loop.run_in_executor(
239
- _get_pool(), self.get_by_fasta, query, threshold
240
- )
241
 
242
  # check if accession number
243
- elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I):
244
- result = await loop.run_in_executor(
245
- _get_pool(), self.get_by_accession, query
246
- )
247
 
248
  else:
249
  # otherwise treat as keyword
250
- result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query)
251
 
252
  if result:
253
  result["_search_query"] = query
 
 
1
  import os
2
  import re
3
  import subprocess
4
  import tempfile
 
 
5
  from io import StringIO
6
  from typing import Dict, Optional
7
 
 
19
  from graphgen.utils import logger
20
 
21
 
 
 
 
 
 
 
 
 
 
22
  class UniProtSearch(BaseSearcher):
23
  """
24
  UniProt Search client to searcher with UniProt.
 
27
  3) Search with FASTA sequence (BLAST searcher). Note that NCBIWWW does not support async.
28
  """
29
 
30
+ def __init__(
31
+ self,
32
+ use_local_blast: bool = False,
33
+ local_blast_db: str = "sp_db",
34
+ blast_num_threads: int = 4,
35
+ threshold: float = 0.01,
36
+ ):
37
  self.use_local_blast = use_local_blast
38
  self.local_blast_db = local_blast_db
39
+ self.blast_num_threads = blast_num_threads # Number of threads for BLAST search
40
+ self.threshold = threshold
41
+
42
  if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.phr"):
43
  logger.error("Local BLAST database files not found. Please check the path.")
44
  self.use_local_blast = False
 
57
 
58
  @staticmethod
59
  def _swissprot_to_dict(record: SwissProt.Record) -> dict:
60
+ """
61
  Convert a SwissProt.Record to a dictionary.
62
  """
63
  functions = []
 
100
  logger.error("Keyword %s not found: %s", keyword, e)
101
  return None
102
 
103
+
104
+ def _parse_fasta_sequence(self, fasta_sequence: str) -> Optional[str]:
105
  """
106
+ Parse and extract sequence from FASTA format.
107
  :param fasta_sequence: The FASTA sequence.
108
+ :return: Extracted sequence string or None if invalid.
 
109
  """
110
  try:
111
  if fasta_sequence.startswith(">"):
112
  seq = str(list(SeqIO.parse(StringIO(fasta_sequence), "fasta"))[0].seq)
113
  else:
114
  seq = fasta_sequence.strip()
115
+ return seq if seq else None
116
  except Exception as e: # pylint: disable=broad-except
117
  logger.error("Invalid FASTA sequence: %s", e)
118
  return None
119
 
120
+ def _search_with_local_blast(self, seq: str, threshold: float) -> Optional[Dict]:
121
+ """Search using local BLAST database."""
122
+ accession = self._local_blast(seq, threshold)
123
+ if not accession:
124
+ logger.info(
125
+ "Local BLAST found no match for sequence. "
126
+ "API fallback disabled when using local database."
127
+ )
128
  return None
129
+ logger.debug("Local BLAST found accession: %s", accession)
130
+ return self.get_by_accession(accession)
131
 
132
+ def _search_with_network_blast(self, seq: str, threshold: float) -> Optional[Dict]:
133
+ """Search using network BLAST (NCBIWWW)."""
134
+ logger.debug("Falling back to NCBIWWW.qblast.")
135
+ try:
136
+ logger.debug("Performing BLAST searcher for the given sequence: %s", seq)
137
+ result_handle = NCBIWWW.qblast(
138
+ program="blastp",
139
+ database="swissprot",
140
+ sequence=seq,
141
+ hitlist_size=1,
142
+ expect=threshold,
143
+ )
144
+ blast_record = NCBIXML.read(result_handle)
145
+ except RequestException:
146
+ raise
147
+ except Exception as e: # pylint: disable=broad-except
148
+ logger.error("BLAST searcher failed: %s", e)
149
+ return None
150
 
151
+ if not blast_record.alignments:
152
+ logger.info("No BLAST hits found for the given sequence.")
153
+ return None
154
 
155
+ best_alignment = blast_record.alignments[0]
156
+ best_hsp = best_alignment.hsps[0]
157
+ if best_hsp.expect > threshold:
158
+ logger.info("No BLAST hits below the threshold E-value.")
159
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ # like sp|P01308.1|INS_HUMAN
162
+ hit_id = best_alignment.hit_id
163
+ accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
164
+ return self.get_by_accession(accession)
165
 
166
+ def get_by_fasta(
167
+ self, fasta_sequence: str, threshold: float
168
+ ) -> Optional[Dict]:
169
+ """Search UniProt with a FASTA sequence and return the best hit."""
170
+ seq = self._parse_fasta_sequence(fasta_sequence)
171
+ if not seq:
172
+ logger.error("Empty FASTA sequence provided.")
173
+ return None
174
 
175
+ search_method = (
176
+ self._search_with_local_blast if self.use_local_blast
177
+ else self._search_with_network_blast
178
+ )
179
+ return search_method(seq, threshold)
180
 
181
  def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
182
  """
183
  Perform local BLAST search using local BLAST database.
184
+ Optimized with multi-threading and faster output format.
 
 
185
  """
186
  try:
187
  with tempfile.NamedTemporaryFile(
 
190
  tmp.write(f">query\n{seq}\n")
191
  tmp_name = tmp.name
192
 
193
+ # Optimized BLAST command with:
194
+ # - num_threads: Use multiple threads for faster search
195
+ # - outfmt 6 sacc: Only return accession (minimal output)
196
+ # - max_target_seqs 1: Only need the best hit
197
+ # - evalue: Threshold for significance
198
  cmd = [
199
  "blastp",
200
  "-db",
 
205
  str(threshold),
206
  "-max_target_seqs",
207
  "1",
208
+ "-num_threads",
209
+ str(self.blast_num_threads),
210
  "-outfmt",
211
+ "6 sacc", # Only accession, tab-separated
212
  ]
213
+ logger.debug(
214
+ "Running local blastp (threads=%d): %s",
215
+ self.blast_num_threads,
216
+ " ".join(cmd),
217
+ )
218
+
219
+ # Run BLAST with timeout to avoid hanging
220
+ try:
221
+ out = subprocess.check_output(
222
+ cmd,
223
+ text=True,
224
+ timeout=300, # 5 minute timeout for BLAST search
225
+ stderr=subprocess.DEVNULL, # Suppress BLAST warnings to reduce I/O
226
+ ).strip()
227
+ except subprocess.TimeoutExpired:
228
+ logger.warning("BLAST search timed out after 5 minutes for sequence")
229
+ os.remove(tmp_name)
230
+ return None
231
+
232
  os.remove(tmp_name)
233
  if out:
234
  return out.split("\n", maxsplit=1)[0]
 
243
  retry=retry_if_exception_type(RequestException),
244
  reraise=True,
245
  )
246
+ def search(self, query: str, threshold: float = None, **kwargs) -> Optional[Dict]:
 
 
247
  """
248
  Search UniProt with either an accession number, keyword, or FASTA sequence.
249
  :param query: The searcher query (accession number, keyword, or FASTA sequence).
250
  :param threshold: E-value threshold for BLAST searcher.
251
  :return: A dictionary containing the best hit information or None if not found.
252
  """
253
+ threshold = threshold or self.threshold
254
  # auto detect query type
255
  if not query or not isinstance(query, str):
256
  logger.error("Empty or non-string input.")
 
259
 
260
  logger.debug("UniProt searcher query: %s", query)
261
 
 
 
262
  # check if fasta sequence
263
  if query.startswith(">") or re.fullmatch(
264
  r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
265
  ):
266
+ result = self.get_by_fasta(query, threshold)
 
 
 
267
 
268
  # check if accession number
269
+ # UniProt accession IDs: 6-10 characters, must start with a letter
270
+ # Format: [A-Z][A-Z0-9]{5,9} (6-10 chars total: 1 letter + 5-9 alphanumeric)
271
+ elif re.fullmatch(r"[A-Z][A-Z0-9]{5,9}", query, re.I):
272
+ result = self.get_by_accession(query)
273
 
274
  else:
275
  # otherwise treat as keyword
276
+ result = self.get_best_hit(query)
277
 
278
  if result:
279
  result["_search_query"] = query
graphgen/models/searcher/web/bing_search.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import requests
2
  from fastapi import HTTPException
3
 
 
1
+ """
2
+ To use Bing Web Search API,
3
+ follow the instructions [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api)
4
+ and obtain your Bing subscription key.
5
+ """
6
+
7
  import requests
8
  from fastapi import HTTPException
9
 
graphgen/models/searcher/web/google_search.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import requests
2
  from fastapi import HTTPException
3
 
 
1
+ """
2
+ To use Google Web Search API,
3
+ follow the instructions [here](https://developers.google.com/custom-search/v1/overview)
4
+ to get your Google searcher api key.
5
+ """
6
+
7
  import requests
8
  from fastapi import HTTPException
9
 
graphgen/operators/__init__.py CHANGED
@@ -6,7 +6,7 @@ from .judge import JudgeService
6
  from .partition import PartitionService
7
  from .quiz import QuizService
8
  from .read import read
9
- from .search import search_all
10
 
11
  operators = {
12
  "read": read,
@@ -15,7 +15,7 @@ operators = {
15
  "quiz": QuizService,
16
  "judge": JudgeService,
17
  "extract": ExtractService,
18
- "search": search_all,
19
  "partition": PartitionService,
20
  "generate": GenerateService,
21
  }
 
6
  from .partition import PartitionService
7
  from .quiz import QuizService
8
  from .read import read
9
+ from .search import SearchService
10
 
11
  operators = {
12
  "read": read,
 
15
  "quiz": QuizService,
16
  "judge": JudgeService,
17
  "extract": ExtractService,
18
+ "search": SearchService,
19
  "partition": PartitionService,
20
  "generate": GenerateService,
21
  }
graphgen/operators/search/__init__.py CHANGED
@@ -1 +1 @@
1
- from .search_all import search_all
 
1
+ from .search_service import SearchService
graphgen/operators/search/search_all.py DELETED
@@ -1,83 +0,0 @@
1
- """
2
- To use Google Web Search API,
3
- follow the instructions [here](https://developers.google.com/custom-search/v1/overview)
4
- to get your Google searcher api key.
5
-
6
- To use Bing Web Search API,
7
- follow the instructions [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api)
8
- and obtain your Bing subscription key.
9
- """
10
-
11
-
12
- from graphgen.utils import logger, run_concurrent
13
-
14
-
15
- async def search_all(
16
- seed_data: dict,
17
- search_config: dict,
18
- ) -> dict:
19
- """
20
- Perform searches across multiple search types and aggregate the results.
21
- :param seed_data: A dictionary containing seed data with entity names.
22
- :param search_config: A dictionary specifying which data sources to use for searching.
23
- :return: A dictionary with
24
- """
25
-
26
- results = {}
27
- data_sources = search_config.get("data_sources", [])
28
-
29
- for data_source in data_sources:
30
- data = list(seed_data.values())
31
- data = [d["content"] for d in data if "content" in d]
32
- data = list(set(data)) # Remove duplicates
33
-
34
- if data_source == "uniprot":
35
- from graphgen.models import UniProtSearch
36
-
37
- uniprot_search_client = UniProtSearch(
38
- **search_config.get("uniprot_params", {})
39
- )
40
-
41
- uniprot_results = await run_concurrent(
42
- uniprot_search_client.search,
43
- data,
44
- desc="Searching UniProt database",
45
- unit="keyword",
46
- )
47
- results[data_source] = uniprot_results
48
-
49
- elif data_source == "ncbi":
50
- from graphgen.models import NCBISearch
51
-
52
- ncbi_search_client = NCBISearch(
53
- **search_config.get("ncbi_params", {})
54
- )
55
-
56
- ncbi_results = await run_concurrent(
57
- ncbi_search_client.search,
58
- data,
59
- desc="Searching NCBI database",
60
- unit="keyword",
61
- )
62
- results[data_source] = ncbi_results
63
-
64
- elif data_source == "rnacentral":
65
- from graphgen.models import RNACentralSearch
66
-
67
- rnacentral_search_client = RNACentralSearch(
68
- **search_config.get("rnacentral_params", {})
69
- )
70
-
71
- rnacentral_results = await run_concurrent(
72
- rnacentral_search_client.search,
73
- data,
74
- desc="Searching RNAcentral database",
75
- unit="keyword",
76
- )
77
- results[data_source] = rnacentral_results
78
-
79
- else:
80
- logger.error("Data source %s not supported.", data_source)
81
- continue
82
-
83
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/operators/search/search_service.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Optional
3
+
4
+ import pandas as pd
5
+
6
+ from graphgen.bases import BaseOperator
7
+ from graphgen.common import init_storage
8
+ from graphgen.utils import compute_content_hash, logger, run_concurrent
9
+
10
+
11
+ class SearchService(BaseOperator):
12
+ """
13
+ Service class for performing searches across multiple data sources.
14
+ Provides search functionality for UniProt, NCBI, and RNAcentral databases.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ working_dir: str = "cache",
20
+ kv_backend: str = "rocksdb",
21
+ data_sources: list = None,
22
+ **kwargs,
23
+ ):
24
+ super().__init__(working_dir=working_dir, op_name="search_service")
25
+ self.working_dir = working_dir
26
+ self.data_sources = data_sources or []
27
+ self.kwargs = kwargs
28
+ self.search_storage = init_storage(
29
+ backend=kv_backend, working_dir=working_dir, namespace="search"
30
+ )
31
+ self.searchers = {}
32
+
33
+ def _init_searchers(self):
34
+ """
35
+ Initialize all searchers (deferred import to avoid circular imports).
36
+ """
37
+ for datasource in self.data_sources:
38
+ if datasource in self.searchers:
39
+ continue
40
+ if datasource == "uniprot":
41
+ from graphgen.models import UniProtSearch
42
+
43
+ params = self.kwargs.get("uniprot_params", {})
44
+ self.searchers[datasource] = UniProtSearch(**params)
45
+ elif datasource == "ncbi":
46
+ from graphgen.models import NCBISearch
47
+
48
+ params = self.kwargs.get("ncbi_params", {})
49
+ self.searchers[datasource] = NCBISearch(**params)
50
+ elif datasource == "rnacentral":
51
+ from graphgen.models import RNACentralSearch
52
+
53
+ params = self.kwargs.get("rnacentral_params", {})
54
+ self.searchers[datasource] = RNACentralSearch(**params)
55
+ else:
56
+ logger.error(f"Unknown data source: {datasource}, skipping")
57
+
58
+ @staticmethod
59
+ async def _perform_search(
60
+ seed: dict, searcher_obj, data_source: str
61
+ ) -> Optional[dict]:
62
+ """
63
+ Perform search for a single seed using the specified searcher.
64
+
65
+ :param seed: The seed document with 'content' field
66
+ :param searcher_obj: The searcher instance
67
+ :param data_source: The data source name
68
+ :return: Search result with metadata
69
+ """
70
+ query = seed.get("content", "")
71
+
72
+ if not query:
73
+ logger.warning("Empty query for seed: %s", seed)
74
+ return None
75
+
76
+ result = searcher_obj.search(query)
77
+ if result:
78
+ result["_doc_id"] = compute_content_hash(str(data_source) + query, "doc-")
79
+ result["data_source"] = data_source
80
+ result["type"] = seed.get("type", "text")
81
+
82
+ return result
83
+
84
+ def _process_single_source(
85
+ self, data_source: str, seed_data: list[dict]
86
+ ) -> list[dict]:
87
+ """
88
+ process a single data source: check cache, search missing, update cache.
89
+ """
90
+ searcher = self.searchers[data_source]
91
+
92
+ seeds_with_ids = []
93
+ for seed in seed_data:
94
+ query = seed.get("content", "")
95
+ if not query:
96
+ continue
97
+ doc_id = compute_content_hash(str(data_source) + query, "doc-")
98
+ seeds_with_ids.append((doc_id, seed))
99
+
100
+ if not seeds_with_ids:
101
+ return []
102
+
103
+ doc_ids = [doc_id for doc_id, _ in seeds_with_ids]
104
+ cached_results = self.search_storage.get_by_ids(doc_ids)
105
+
106
+ to_search_seeds = []
107
+ final_results = []
108
+
109
+ for (doc_id, seed), cached in zip(seeds_with_ids, cached_results):
110
+ if cached is not None:
111
+ if "_doc_id" not in cached:
112
+ cached["_doc_id"] = doc_id
113
+ final_results.append(cached)
114
+ else:
115
+ to_search_seeds.append(seed)
116
+
117
+ if to_search_seeds:
118
+ new_results = run_concurrent(
119
+ partial(
120
+ self._perform_search, searcher_obj=searcher, data_source=data_source
121
+ ),
122
+ to_search_seeds,
123
+ desc=f"Searching {data_source} database",
124
+ unit="keyword",
125
+ )
126
+ new_results = [res for res in new_results if res is not None]
127
+
128
+ if new_results:
129
+ upsert_data = {res["_doc_id"]: res for res in new_results}
130
+ self.search_storage.upsert(upsert_data)
131
+ logger.info(
132
+ f"Saved {len(upsert_data)} new results to {data_source} cache"
133
+ )
134
+
135
+ final_results.extend(new_results)
136
+
137
+ return final_results
138
+
139
+ def process(self, batch: pd.DataFrame) -> pd.DataFrame:
140
+ docs = batch.to_dict(orient="records")
141
+
142
+ self._init_searchers()
143
+
144
+ seed_data = [doc for doc in docs if doc and "content" in doc]
145
+
146
+ if not seed_data:
147
+ logger.warning("No valid seeds in batch")
148
+ return pd.DataFrame([])
149
+
150
+ all_results = []
151
+
152
+ for data_source in self.data_sources:
153
+ if data_source not in self.searchers:
154
+ logger.error(f"Data source {data_source} not initialized, skipping")
155
+ continue
156
+
157
+ source_results = self._process_single_source(data_source, seed_data)
158
+ all_results.extend(source_results)
159
+
160
+ if not all_results:
161
+ logger.warning("No search results generated for this batch")
162
+
163
+ return pd.DataFrame(all_results)