github-actions[bot] committed
Commit 36059f0 · Parent: bf1f6d4

Auto-sync from demo at Wed Feb 4 14:58:57 UTC 2026

graphgen/operators/search/search_service.py CHANGED
@@ -1,9 +1,9 @@
 from functools import partial
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from graphgen.bases import BaseOperator
 from graphgen.common.init_storage import init_storage
-from graphgen.utils import compute_content_hash, logger, run_concurrent
+from graphgen.utils import logger, run_concurrent
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -19,42 +19,47 @@ class SearchService(BaseOperator):
         self,
         working_dir: str = "cache",
         kv_backend: str = "rocksdb",
-        data_sources: list = None,
+        data_source: str = None,
         **kwargs,
     ):
-        super().__init__(working_dir=working_dir, op_name="search_service")
-        self.working_dir = working_dir
-        self.data_sources = data_sources or []
+        super().__init__(
+            working_dir=working_dir, kv_backend=kv_backend, op_name="search"
+        )
+        self.data_source = data_source
         self.kwargs = kwargs
         self.search_storage = init_storage(
             backend=kv_backend, working_dir=working_dir, namespace="search"
         )
-        self.searchers = {}
+        self.searcher = None
 
-    def _init_searchers(self):
+    def _init_searcher(self):
         """
-        Initialize all searchers (deferred import to avoid circular imports).
+        Initialize the searcher (deferred import to avoid circular imports).
         """
-        for datasource in self.data_sources:
-            if datasource in self.searchers:
-                continue
-            if datasource == "uniprot":
-                from graphgen.models import UniProtSearch
+        if self.searcher is not None:
+            return
+
+        if not self.data_source:
+            logger.error("Data source not specified")
+            return
 
-                params = self.kwargs.get("uniprot_params", {})
-                self.searchers[datasource] = UniProtSearch(**params)
-            elif datasource == "ncbi":
-                from graphgen.models import NCBISearch
+        if self.data_source == "uniprot":
+            from graphgen.models import UniProtSearch
 
-                params = self.kwargs.get("ncbi_params", {})
-                self.searchers[datasource] = NCBISearch(**params)
-            elif datasource == "rnacentral":
-                from graphgen.models import RNACentralSearch
+            params = self.kwargs.get("uniprot_params", {})
+            self.searcher = UniProtSearch(**params)
+        elif self.data_source == "ncbi":
+            from graphgen.models import NCBISearch
 
-                params = self.kwargs.get("rnacentral_params", {})
-                self.searchers[datasource] = RNACentralSearch(**params)
-            else:
-                logger.error(f"Unknown data source: {datasource}, skipping")
+            params = self.kwargs.get("ncbi_params", {})
+            self.searcher = NCBISearch(**params)
+        elif self.data_source == "rnacentral":
+            from graphgen.models import RNACentralSearch
+
+            params = self.kwargs.get("rnacentral_params", {})
+            self.searcher = RNACentralSearch(**params)
+        else:
+            logger.error(f"Unknown data source: {self.data_source}")
 
     @staticmethod
     async def _perform_search(
@@ -76,91 +81,59 @@
 
         result = searcher_obj.search(query)
         if result:
-            result["_doc_id"] = compute_content_hash(str(data_source) + query, "doc-")
             result["data_source"] = data_source
             result["type"] = seed.get("type", "text")
 
         return result
 
-    def _process_single_source(
-        self, data_source: str, seed_data: list[dict]
-    ) -> list[dict]:
-        """
-        process a single data source: check cache, search missing, update cache.
-        """
-        searcher = self.searchers[data_source]
-
-        seeds_with_ids = []
-        for seed in seed_data:
-            query = seed.get("content", "")
-            if not query:
-                continue
-            doc_id = compute_content_hash(str(data_source) + query, "doc-")
-            seeds_with_ids.append((doc_id, seed))
-
-        if not seeds_with_ids:
-            return []
-
-        doc_ids = [doc_id for doc_id, _ in seeds_with_ids]
-        cached_results = self.search_storage.get_by_ids(doc_ids)
-
-        to_search_seeds = []
-        final_results = []
-
-        for (doc_id, seed), cached in zip(seeds_with_ids, cached_results):
-            if cached is not None:
-                if "_doc_id" not in cached:
-                    cached["_doc_id"] = doc_id
-                final_results.append(cached)
-            else:
-                to_search_seeds.append(seed)
-
-        if to_search_seeds:
-            new_results = run_concurrent(
-                partial(
-                    self._perform_search, searcher_obj=searcher, data_source=data_source
-                ),
-                to_search_seeds,
-                desc=f"Searching {data_source} database",
-                unit="keyword",
-            )
-            new_results = [res for res in new_results if res is not None]
-
-            if new_results:
-                upsert_data = {res["_doc_id"]: res for res in new_results}
-                self.search_storage.upsert(upsert_data)
-                logger.info(
-                    f"Saved {len(upsert_data)} new results to {data_source} cache"
-                )
-
-            final_results.extend(new_results)
-
-        return final_results
-
-    def process(self, batch: "pd.DataFrame") -> "pd.DataFrame":
-        import pandas as pd
-
-        docs = batch.to_dict(orient="records")
-
-        self._init_searchers()
-
-        seed_data = [doc for doc in docs if doc and "content" in doc]
+    def process(self, batch: list) -> Tuple[list, dict]:
+        """
+        Search for items in the batch using the configured data source.
+
+        :param batch: List of items with 'content' and '_trace_id' fields
+        :return: A tuple of (results, meta_updates)
+            results: A list of search results.
+            meta_updates: A dict mapping source IDs to lists of trace IDs for the search results.
+        """
+        self._init_searcher()
+
+        if not self.searcher:
+            logger.error("Searcher not initialized")
+            return [], {}
+
+        # Filter seeds with valid content and _trace_id
+        seed_data = [
+            item for item in batch if item and "content" in item and "_trace_id" in item
+        ]
 
         if not seed_data:
             logger.warning("No valid seeds in batch")
-            return pd.DataFrame([])
-
-        all_results = []
-
-        for data_source in self.data_sources:
-            if data_source not in self.searchers:
-                logger.error(f"Data source {data_source} not initialized, skipping")
-                continue
-
-            source_results = self._process_single_source(data_source, seed_data)
-            all_results.extend(source_results)
-
-        if not all_results:
-            logger.warning("No search results generated for this batch")
-
-        return pd.DataFrame(all_results)
+            return [], {}
+
+        # Perform concurrent searches
+        results = run_concurrent(
+            partial(
+                self._perform_search,
+                searcher_obj=self.searcher,
+                data_source=self.data_source,
+            ),
+            seed_data,
+            desc=f"Searching {self.data_source} database",
+            unit="keyword",
+        )
+
+        # Filter out None results and add _trace_id from original seeds
+        final_results = []
+        meta_updates = {}
+        for result, seed in zip(results, seed_data):
+            if result is None:
+                continue
+            result["_trace_id"] = self.get_trace_id(result)
+            final_results.append(result)
+            # Map from source seed trace ID to search result trace ID
+            meta_updates.setdefault(seed["_trace_id"], []).append(result["_trace_id"])
+
+        if not final_results:
+            logger.warning("No search results generated for this batch")
+
+        return final_results, meta_updates
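
For context, this diff changes the operator's contract from a pandas-DataFrame-in/DataFrame-out call that fanned out over data_sources=[...] to a plain-list call against a single configured data_source, returning a (results, meta_updates) tuple keyed by trace IDs. Below is a minimal usage sketch of the new interface. It is an assumption-laden illustration, not a documented API: the import path, the seed contents ("P04637", "TP53"), and the expectation that get_trace_id is provided by BaseOperator are all inferred from this diff alone, and a real run would issue live UniProt queries.

# Hypothetical driver for the refactored operator; module path and seed
# layout are assumptions based on this diff, not a documented API.
from graphgen.operators.search.search_service import SearchService

# One operator per data source now, instead of one operator fanning out
# over data_sources=[...] as in the old version.
service = SearchService(
    working_dir="cache",
    kv_backend="rocksdb",
    data_source="uniprot",
    uniprot_params={},  # forwarded as UniProtSearch(**params)
)

# Each seed must carry 'content' (the query) and '_trace_id' (lineage key);
# items missing either field are filtered out by process().
batch = [
    {"content": "P04637", "_trace_id": "seed-1", "type": "protein"},
    {"content": "TP53", "_trace_id": "seed-2"},
]

# process() now returns (results, meta_updates) rather than a DataFrame:
# results is a list of enriched search hits; meta_updates maps each seed's
# _trace_id to the trace IDs of the results derived from it.
results, meta_updates = service.process(batch)
for seed_trace, result_traces in meta_updates.items():
    print(seed_trace, "->", result_traces)

Note that the per-source cache lookup from the old _process_single_source (get_by_ids plus upsert against search_storage, keyed by compute_content_hash doc IDs) does not reappear in the new process() shown here, so deduplication against the "search" namespace is presumably handled elsewhere or dropped by this commit.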