from functools import partial
from typing import Optional

import pandas as pd

from graphgen.bases import BaseOperator
from graphgen.common import init_storage
from graphgen.utils import compute_content_hash, logger, run_concurrent


class SearchService(BaseOperator):
    """
    Service class for performing searches across multiple data sources.
    Provides search functionality for UniProt, NCBI, and RNAcentral databases.
    """

    def __init__(
        self,
        working_dir: str = "cache",
        kv_backend: str = "rocksdb",
        data_sources: Optional[list] = None,
        **kwargs,
    ):
        super().__init__(working_dir=working_dir, op_name="search_service")
        self.working_dir = working_dir
        self.data_sources = data_sources or []
        self.kwargs = kwargs
        self.search_storage = init_storage(
            backend=kv_backend, working_dir=working_dir, namespace="search"
        )
        # Lazily populated by _init_searchers(); maps data source name -> searcher.
        self.searchers = {}

    def _init_searchers(self):
        """
        Initialize all searchers (deferred import to avoid circular imports).
        """
        for datasource in self.data_sources:
            if datasource in self.searchers:
                continue
            if datasource == "uniprot":
                from graphgen.models import UniProtSearch

                params = self.kwargs.get("uniprot_params", {})
                self.searchers[datasource] = UniProtSearch(**params)
            elif datasource == "ncbi":
                from graphgen.models import NCBISearch

                params = self.kwargs.get("ncbi_params", {})
                self.searchers[datasource] = NCBISearch(**params)
            elif datasource == "rnacentral":
                from graphgen.models import RNACentralSearch

                params = self.kwargs.get("rnacentral_params", {})
                self.searchers[datasource] = RNACentralSearch(**params)
            else:
                logger.error(f"Unknown data source: {datasource}, skipping")

    @staticmethod
    async def _perform_search(
        seed: dict, searcher_obj, data_source: str
    ) -> Optional[dict]:
        """
        Perform a search for a single seed using the specified searcher.

        :param seed: The seed document with a 'content' field
        :param searcher_obj: The searcher instance
        :param data_source: The data source name
        :return: Search result with metadata, or None if the query is empty
        """
        query = seed.get("content", "")
        if not query:
            logger.warning("Empty query for seed: %s", seed)
            return None
        result = searcher_obj.search(query)
        if result:
            result["_doc_id"] = compute_content_hash(str(data_source) + query, "doc-")
            result["data_source"] = data_source
            result["type"] = seed.get("type", "text")
        return result
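
    # Cache-key sketch: the "_doc_id" computed above must match the id computed
    # in _process_single_source below, otherwise cache lookups would always miss.
    # Assuming compute_content_hash returns the given prefix plus a digest of the
    # content string, the mapping looks like:
    #
    #   compute_content_hash("uniprot" + "insulin", "doc-")  # -> "doc-<digest>"
    #
    # where "insulin" is a placeholder query used only for illustration.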

    def _process_single_source(
        self, data_source: str, seed_data: list[dict]
    ) -> list[dict]:
        """
        Process a single data source: check the cache, search for misses,
        and update the cache.
        """
        searcher = self.searchers[data_source]

        # Pair each non-empty seed with its deterministic cache id.
        seeds_with_ids = []
        for seed in seed_data:
            query = seed.get("content", "")
            if not query:
                continue
            doc_id = compute_content_hash(str(data_source) + query, "doc-")
            seeds_with_ids.append((doc_id, seed))
        if not seeds_with_ids:
            return []

        # Split seeds into cache hits and seeds that still need a live search.
        doc_ids = [doc_id for doc_id, _ in seeds_with_ids]
        cached_results = self.search_storage.get_by_ids(doc_ids)
        to_search_seeds = []
        final_results = []
        for (doc_id, seed), cached in zip(seeds_with_ids, cached_results):
            if cached is not None:
                if "_doc_id" not in cached:
                    cached["_doc_id"] = doc_id
                final_results.append(cached)
            else:
                to_search_seeds.append(seed)

        if to_search_seeds:
            new_results = run_concurrent(
                partial(
                    self._perform_search, searcher_obj=searcher, data_source=data_source
                ),
                to_search_seeds,
                desc=f"Searching {data_source} database",
                unit="keyword",
            )
            new_results = [res for res in new_results if res is not None]
            if new_results:
                # Persist fresh results so later batches hit the cache.
                upsert_data = {res["_doc_id"]: res for res in new_results}
                self.search_storage.upsert(upsert_data)
                logger.info(
                    f"Saved {len(upsert_data)} new results to {data_source} cache"
                )
            final_results.extend(new_results)
        return final_results

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        """
        Run every configured data source over the batch and return the
        combined search results as a DataFrame.
        """
        docs = batch.to_dict(orient="records")
        self._init_searchers()

        seed_data = [doc for doc in docs if doc and "content" in doc]
        if not seed_data:
            logger.warning("No valid seeds in batch")
            return pd.DataFrame([])

        all_results = []
        for data_source in self.data_sources:
            if data_source not in self.searchers:
                logger.error(f"Data source {data_source} not initialized, skipping")
                continue
            source_results = self._process_single_source(data_source, seed_data)
            all_results.extend(source_results)

        if not all_results:
            logger.warning("No search results generated for this batch")
        return pd.DataFrame(all_results)
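

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal driver, assuming the graphgen package and a RocksDB-backed KV store
# are available. The seed content "insulin" is a placeholder query.
if __name__ == "__main__":
    service = SearchService(
        working_dir="cache",
        kv_backend="rocksdb",
        data_sources=["uniprot"],
    )
    batch = pd.DataFrame([{"content": "insulin", "type": "text"}])
    results = service.process(batch)
    print(results.head())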