GraphGen / graphgen / operators / search / search_service.py
from functools import partial
from typing import Optional

import pandas as pd

from graphgen.bases import BaseOperator
from graphgen.common import init_storage
from graphgen.utils import compute_content_hash, logger, run_concurrent


class SearchService(BaseOperator):
"""
Service class for performing searches across multiple data sources.
Provides search functionality for UniProt, NCBI, and RNAcentral databases.
"""
def __init__(
self,
working_dir: str = "cache",
kv_backend: str = "rocksdb",
        data_sources: Optional[list] = None,
**kwargs,
):
super().__init__(working_dir=working_dir, op_name="search_service")
self.working_dir = working_dir
self.data_sources = data_sources or []
self.kwargs = kwargs
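        # Key-value storage used to cache search results under the "search" namespace.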
self.search_storage = init_storage(
backend=kv_backend, working_dir=working_dir, namespace="search"
)
        self.searchers = {}

    def _init_searchers(self):
        """
        Initialize searchers for the configured data sources
        (imports are deferred to avoid circular imports).
        """
for datasource in self.data_sources:
if datasource in self.searchers:
continue
if datasource == "uniprot":
from graphgen.models import UniProtSearch
params = self.kwargs.get("uniprot_params", {})
self.searchers[datasource] = UniProtSearch(**params)
elif datasource == "ncbi":
from graphgen.models import NCBISearch
params = self.kwargs.get("ncbi_params", {})
self.searchers[datasource] = NCBISearch(**params)
elif datasource == "rnacentral":
from graphgen.models import RNACentralSearch
params = self.kwargs.get("rnacentral_params", {})
self.searchers[datasource] = RNACentralSearch(**params)
else:
logger.error(f"Unknown data source: {datasource}, skipping")
@staticmethod
async def _perform_search(
seed: dict, searcher_obj, data_source: str
) -> Optional[dict]:
"""
Perform search for a single seed using the specified searcher.
:param seed: The seed document with 'content' field
:param searcher_obj: The searcher instance
:param data_source: The data source name
        :return: Search result with metadata, or None if the seed has no content
"""
query = seed.get("content", "")
if not query:
logger.warning("Empty query for seed: %s", seed)
return None
result = searcher_obj.search(query)
if result:
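            # Tag the result with provenance metadata (doc id, data source, seed type).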
result["_doc_id"] = compute_content_hash(str(data_source) + query, "doc-")
result["data_source"] = data_source
result["type"] = seed.get("type", "text")
        return result

    def _process_single_source(
self, data_source: str, seed_data: list[dict]
) -> list[dict]:
"""
        Process a single data source: check the cache for existing results,
        search only the misses, and write new results back to the cache.
"""
searcher = self.searchers[data_source]
seeds_with_ids = []
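        # Derive a deterministic doc id from the data source and query so the same
        # key is used for cache lookups and upserts.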
for seed in seed_data:
query = seed.get("content", "")
if not query:
continue
doc_id = compute_content_hash(str(data_source) + query, "doc-")
seeds_with_ids.append((doc_id, seed))
if not seeds_with_ids:
return []
doc_ids = [doc_id for doc_id, _ in seeds_with_ids]
cached_results = self.search_storage.get_by_ids(doc_ids)
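        # Split seeds into cache hits (reused directly) and misses (searched below).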
to_search_seeds = []
final_results = []
for (doc_id, seed), cached in zip(seeds_with_ids, cached_results):
if cached is not None:
if "_doc_id" not in cached:
cached["_doc_id"] = doc_id
final_results.append(cached)
else:
to_search_seeds.append(seed)
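        # Search cache misses concurrently, then persist any new results to the cache.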
if to_search_seeds:
new_results = run_concurrent(
partial(
self._perform_search, searcher_obj=searcher, data_source=data_source
),
to_search_seeds,
desc=f"Searching {data_source} database",
unit="keyword",
)
new_results = [res for res in new_results if res is not None]
if new_results:
upsert_data = {res["_doc_id"]: res for res in new_results}
self.search_storage.upsert(upsert_data)
logger.info(
f"Saved {len(upsert_data)} new results to {data_source} cache"
)
final_results.extend(new_results)
        return final_results

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
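        """
        Run the configured searchers over a batch of seed documents and
        return the combined results as a DataFrame.
        """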
docs = batch.to_dict(orient="records")
self._init_searchers()
seed_data = [doc for doc in docs if doc and "content" in doc]
if not seed_data:
logger.warning("No valid seeds in batch")
return pd.DataFrame([])
all_results = []
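        # Query each configured data source in turn and concatenate the results.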
for data_source in self.data_sources:
if data_source not in self.searchers:
logger.error(f"Data source {data_source} not initialized, skipping")
continue
source_results = self._process_single_source(data_source, seed_data)
all_results.extend(source_results)
if not all_results:
logger.warning("No search results generated for this batch")
return pd.DataFrame(all_results)
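

# Example usage (illustrative sketch, not part of the original module): assumes
# seed documents carry a "content" field and an optional "type" field, and that
# the listed data sources are backed by searchers available in graphgen.models.
if __name__ == "__main__":
    seeds = pd.DataFrame(
        [
            {"content": "P69905", "type": "protein"},
            {"content": "BRCA1", "type": "gene"},
        ]
    )
    service = SearchService(
        working_dir="cache",
        kv_backend="rocksdb",
        data_sources=["uniprot", "ncbi"],
    )
    results = service.process(seeds)
    print(results.head())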