Arjunvir Singh commited on
Commit
52764bf
·
1 Parent(s): f2ac076

Fix ZeroGPU state-loss in EmbeddingRetriever; add einops + smoke surfaces

Browse files

The @spaces.GPU decorator runs wrapped functions in a separate worker
process; mutations to `self` inside the worker do not propagate back to
the caller. Decorating EmbeddingRetriever.index() / .query() therefore
silently set the indexed vectors in the worker only, leaving the calling
instance with empty state and producing recall=0 even though the model
loaded and encoded successfully on GPU.

Refactor: GPU work moved to a free stateless helper
`_gpu_encode_batch(model_id, task, texts) -> vectors`. EmbeddingRetriever
methods stay in the main process and dispatch via a new `_encode` method
that picks the explicit injected embedder (test path) or the GPU helper
(production path). Same fix pattern applies to TransformersClient when
we wire live GPU repair end-to-end (deferred — current decorator is
single-shot bursty work where state loss doesn't matter for one call).

Also lands:
- einops>=0.7.0 in requirements.txt (jina-v3's xlm_roberta_flash custom
modeling needs it; sentence-transformers does not pull it in).
- pyproject.toml `embedding` extra updated to match.
- scripts/run_space_smoke.py honours ZSGDP_SMOKE_EMBEDDING_MODEL_ID env
var so operators can swap models without editing the script (e.g. to
sentence-transformers/all-MiniLM-L6-v2 when jina-v3 has transformers
compat issues).
- app.py exposes run_smokes_in_space as a callable function so the
smokes can be triggered from the Gradio API or a future button.

Test count: 250/250.

app.py CHANGED
@@ -219,6 +219,36 @@ def runtime_status_for_mode(pipeline_mode: str) -> dict:
219
  return collect_gpu_runtime_status(load_config(_config_path_for_mode(pipeline_mode))).to_dict()
220
 
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  with gr.Blocks(title="zeroshotGPU") as demo:
223
  gr.Markdown("# zeroshotGPU")
224
  with gr.Row():
 
219
  return collect_gpu_runtime_status(load_config(_config_path_for_mode(pipeline_mode))).to_dict()
220
 
221
 
222
+ def run_smokes_in_space() -> dict:
223
+ """Run scripts/run_space_smoke.py inside the Space and return the JSON report.
224
+
225
+ Exposes the in-process smoke runner as a Gradio endpoint so it's callable
226
+ from the UI tab AND from `/gradio_api/call/run_smokes_in_space` remotely.
227
+ Same code path as the terminal `python -m scripts.run_space_smoke` — just
228
+ triggered through Gradio instead of an SSH session.
229
+
230
+ Returns the same dict shape as SmokeReport.to_dict(): per-smoke results
231
+ with status / elapsed / detail / skip_reason / install_hint, plus an
232
+ aggregate summary count block.
233
+ """
234
+
235
+ from scripts.run_space_smoke import run_smokes
236
+
237
+ _logger.info("space_smokes_requested", extra={"trigger": "gradio_endpoint"})
238
+ report = run_smokes()
239
+ payload = report.to_dict()
240
+ _logger.info(
241
+ "space_smokes_complete",
242
+ extra={
243
+ "passed": payload["summary"]["passed"],
244
+ "failed": payload["summary"]["failed"],
245
+ "skipped": payload["summary"]["skipped"],
246
+ "errored": payload["summary"]["errored"],
247
+ },
248
+ )
249
+ return payload
250
+
251
+
252
  with gr.Blocks(title="zeroshotGPU") as demo:
253
  gr.Markdown("# zeroshotGPU")
254
  with gr.Row():
pyproject.toml CHANGED
@@ -25,7 +25,13 @@ spaces = [
25
  "pyyaml>=6.0.1,<7.0.0",
26
  "docling>=2.0.0,<3.0.0",
27
  ]
28
- embedding = ["sentence-transformers>=3.0.0,<4.0.0", "transformers>=4.45.0,<6.0.0"]
 
 
 
 
 
 
29
  gpu_repair = ["transformers>=4.45.0,<6.0.0"]
30
  dev = ["pytest>=8.0.0"]
31
 
 
25
  "pyyaml>=6.0.1,<7.0.0",
26
  "docling>=2.0.0,<3.0.0",
27
  ]
28
+ embedding = [
29
+ "sentence-transformers>=3.0.0,<4.0.0",
30
+ "transformers>=4.45.0,<6.0.0",
31
+ # jinaai/jina-embeddings-v3's custom modeling needs einops; not pulled
32
+ # in transitively by sentence-transformers.
33
+ "einops>=0.7.0",
34
+ ]
35
  gpu_repair = ["transformers>=4.45.0,<6.0.0"]
36
  dev = ["pytest>=8.0.0"]
37
 
requirements.txt CHANGED
@@ -21,12 +21,17 @@ docling>=2.0.0,<3.0.0
21
  # through to a passthrough decorator (see zsgdp/gpu/zero_gpu.py).
22
  spaces>=0.25.0
23
 
24
- # Optional GPU/embedding stack. Uncomment to enable the embedding retriever
25
- # (benchmarks.retriever.backend=embedding) and live GPU repair escalations
26
- # (repair.execute_gpu_escalations=true). Both are off by default.
 
 
 
 
27
  #
28
  transformers>=4.45.0,<6.0.0
29
  sentence-transformers>=3.0.0,<4.0.0
 
30
 
31
  # Optional external parser CLIs. Each adds a non-trivial install footprint;
32
  # enable only the ones the Space hardware can support. Adapter shells out to
 
21
  # through to a passthrough decorator (see zsgdp/gpu/zero_gpu.py).
22
  spaces>=0.25.0
23
 
24
+ # Embedding retriever + live GPU repair stack. Enabled here because the
25
+ # Space is provisioned for the full evaluation surface; comment out the
26
+ # group if you want a CPU-only deploy with just the lexical retriever.
27
+ #
28
+ # einops is required by jinaai/jina-embeddings-v3's custom modeling code
29
+ # (it ships a custom `xlm_roberta_flash` implementation that reshapes via
30
+ # einops); pip-installing sentence-transformers alone does not pull it in.
31
  #
32
  transformers>=4.45.0,<6.0.0
33
  sentence-transformers>=3.0.0,<4.0.0
34
+ einops>=0.7.0
35
 
36
  # Optional external parser CLIs. Each adds a non-trivial install footprint;
37
  # enable only the ones the Space hardware can support. Adapter shells out to
scripts/run_space_smoke.py CHANGED
@@ -179,6 +179,16 @@ def smoke_ablation() -> SmokeResult:
179
 
180
 
181
  def smoke_embedding() -> SmokeResult:
 
 
 
 
 
 
 
 
 
 
182
  started = time.perf_counter()
183
  if importlib.util.find_spec("sentence_transformers") is None:
184
  return SmokeResult(
@@ -189,15 +199,19 @@ def smoke_embedding() -> SmokeResult:
189
  install_hint="python -m pip install 'zero-shot-gpu-doc-parser[embedding]'",
190
  )
191
 
 
 
192
  from zsgdp.benchmarks.embedding_retriever import EmbeddingRetriever
193
  from zsgdp.benchmarks.parser_quality import run_parser_benchmark
194
 
 
 
195
  # Try to load the configured embedding model. If the load fails (no HF
196
  # token, download error, OOM at import time), we report it as a skip
197
  # with the exception text so the operator sees what to fix without the
198
  # whole smoke run blowing up.
199
  try:
200
- retriever = EmbeddingRetriever()
201
  retriever._ensure_embedder() # type: ignore[attr-defined] # private but intentional
202
  except Exception as exc:
203
  return SmokeResult(
@@ -205,21 +219,23 @@ def smoke_embedding() -> SmokeResult:
205
  status="skip",
206
  elapsed_seconds=time.perf_counter() - started,
207
  skip_reason=f"embedding model failed to load: {exc}",
208
- install_hint="Set HF_TOKEN if the model is gated, or downsize via "
209
- "benchmarks.retriever.model_id (e.g. sentence-transformers/all-MiniLM-L6-v2).",
 
210
  )
211
 
212
- config_overrides = {"benchmarks": {"retriever": {"backend": "embedding"}}}
213
  with tempfile.TemporaryDirectory() as tmp:
214
  tmp_path = Path(tmp)
215
  src = _make_distinctive_corpus(tmp_path)
216
  out = tmp_path / "out"
217
  config_path = tmp_path / "config.yaml"
218
- # Inline config write — keeps the smoke self-contained.
219
- config_path.write_text(
220
- "benchmarks:\n retriever:\n backend: embedding\n",
221
- encoding="utf-8",
222
- )
 
 
223
  try:
224
  summary = run_parser_benchmark(src, out, config_path=config_path, dataset_name="custom_folder")
225
  except Exception as exc:
 
179
 
180
 
181
  def smoke_embedding() -> SmokeResult:
182
+ """Validate the embedding-retriever wiring on a real Space.
183
+
184
+ Set ZSGDP_SMOKE_EMBEDDING_MODEL_ID to override the default model_id —
185
+ useful when the configured default (jinaai/jina-embeddings-v3) has
186
+ transformers-version compat issues with the running container. A
187
+ common safe fallback is `sentence-transformers/all-MiniLM-L6-v2`,
188
+ which has no custom remote modeling code and works with any
189
+ transformers version.
190
+ """
191
+
192
  started = time.perf_counter()
193
  if importlib.util.find_spec("sentence_transformers") is None:
194
  return SmokeResult(
 
199
  install_hint="python -m pip install 'zero-shot-gpu-doc-parser[embedding]'",
200
  )
201
 
202
+ import os
203
+
204
  from zsgdp.benchmarks.embedding_retriever import EmbeddingRetriever
205
  from zsgdp.benchmarks.parser_quality import run_parser_benchmark
206
 
207
+ override_model_id = os.environ.get("ZSGDP_SMOKE_EMBEDDING_MODEL_ID") or None
208
+
209
  # Try to load the configured embedding model. If the load fails (no HF
210
  # token, download error, OOM at import time), we report it as a skip
211
  # with the exception text so the operator sees what to fix without the
212
  # whole smoke run blowing up.
213
  try:
214
+ retriever = EmbeddingRetriever(model_id=override_model_id) if override_model_id else EmbeddingRetriever()
215
  retriever._ensure_embedder() # type: ignore[attr-defined] # private but intentional
216
  except Exception as exc:
217
  return SmokeResult(
 
219
  status="skip",
220
  elapsed_seconds=time.perf_counter() - started,
221
  skip_reason=f"embedding model failed to load: {exc}",
222
+ install_hint="Set HF_TOKEN if the model is gated, OR set "
223
+ "ZSGDP_SMOKE_EMBEDDING_MODEL_ID=sentence-transformers/all-MiniLM-L6-v2 "
224
+ "to use a smaller compat-friendly model.",
225
  )
226
 
 
227
  with tempfile.TemporaryDirectory() as tmp:
228
  tmp_path = Path(tmp)
229
  src = _make_distinctive_corpus(tmp_path)
230
  out = tmp_path / "out"
231
  config_path = tmp_path / "config.yaml"
232
+ # Inline config write — keeps the smoke self-contained. Honours the
233
+ # env-var model override so the operator can swap models without
234
+ # editing this script.
235
+ config_lines = ["benchmarks:", " retriever:", " backend: embedding"]
236
+ if override_model_id:
237
+ config_lines.append(f" model_id: {override_model_id}")
238
+ config_path.write_text("\n".join(config_lines) + "\n", encoding="utf-8")
239
  try:
240
  summary = run_parser_benchmark(src, out, config_path=config_path, dataset_name="custom_folder")
241
  except Exception as exc:
tests/test_space_smoke.py CHANGED
@@ -138,6 +138,28 @@ class RunSmokesIntegrationTests(unittest.TestCase):
138
  self.assertIn("sentence-transformers", result.skip_reason)
139
  self.assertIn("pip install", result.install_hint)
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  def test_marker_smoke_skips_when_binary_missing(self):
142
  with patch("scripts.run_space_smoke.shutil.which", return_value=None):
143
  result = smoke_marker()
 
138
  self.assertIn("sentence-transformers", result.skip_reason)
139
  self.assertIn("pip install", result.install_hint)
140
 
141
+ def test_embedding_smoke_install_hint_mentions_model_override(self):
142
+ # When the model fails to load (e.g. jina-v3 transformers compat),
143
+ # the install_hint must point at the env-var override path so the
144
+ # operator can immediately switch to a compat-friendly model.
145
+ # Patch where EmbeddingRetriever is *defined*, not where it's imported,
146
+ # because smoke_embedding does a function-local lazy import.
147
+ from unittest.mock import MagicMock
148
+
149
+ retriever_mock = MagicMock()
150
+ retriever_mock.return_value._ensure_embedder.side_effect = RuntimeError("synthetic load failure")
151
+
152
+ with patch("scripts.run_space_smoke.importlib.util.find_spec") as find_spec, patch(
153
+ "zsgdp.benchmarks.embedding_retriever.EmbeddingRetriever", retriever_mock
154
+ ):
155
+ find_spec.return_value = object() # spec found, dep present
156
+ result = smoke_embedding()
157
+
158
+ self.assertEqual(result.status, "skip")
159
+ self.assertIn("synthetic load failure", result.skip_reason)
160
+ self.assertIn("ZSGDP_SMOKE_EMBEDDING_MODEL_ID", result.install_hint)
161
+ self.assertIn("all-MiniLM-L6-v2", result.install_hint)
162
+
163
  def test_marker_smoke_skips_when_binary_missing(self):
164
  with patch("scripts.run_space_smoke.shutil.which", return_value=None):
165
  result = smoke_marker()
zsgdp/benchmarks/embedding_retriever.py CHANGED
@@ -12,8 +12,7 @@ Definitions and contract (pinned):
12
  - Pass `embedder=...` directly (used by tests and any caller that wants
13
  full control over batching, device placement, or remote inference).
14
  - Pass `model_id=...` and let the retriever lazy-load
15
- sentence-transformers. Selected through `build_retriever` from config
16
- by setting `benchmarks.retriever.backend = "embedding"`.
17
  - Index and query both call the embedder. The retriever is stateless
18
  beyond the indexed chunk vectors; reusing across documents requires a
19
  fresh `index()` call, same contract as LexicalRetriever.
@@ -22,6 +21,14 @@ Definitions and contract (pinned):
22
  Other sentence-transformers models work as long as they accept the same
23
  encode signature; jina's task-prompt argument is optional and silently
24
  ignored by models that don't recognize it.
 
 
 
 
 
 
 
 
25
  """
26
 
27
  from __future__ import annotations
@@ -34,6 +41,38 @@ from zsgdp.schema import Chunk
34
  Embedder = Callable[[list[str]], list[list[float]]]
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  class EmbeddingRetriever:
38
  def __init__(
39
  self,
@@ -51,17 +90,16 @@ class EmbeddingRetriever:
51
  self._chunk_ids: list[str] = []
52
  self._vectors: list[list[float]] = []
53
 
54
- @zero_gpu_slot(duration=180)
55
  def index(self, chunks: Sequence[Chunk]) -> None:
56
- # First call lazy-loads the model + encodes all chunks (the slow path);
57
- # ZeroGPU slot covers both. No-op decorator off-Space.
58
- embedder = self._ensure_embedder()
59
  texts = [chunk.text for chunk in chunks]
60
  if not texts:
61
  self._chunk_ids = []
62
  self._vectors = []
63
  return
64
- vectors = embedder(texts)
65
  if len(vectors) != len(texts):
66
  raise RuntimeError(
67
  f"EmbeddingRetriever embedder returned {len(vectors)} vectors for {len(texts)} chunks."
@@ -69,12 +107,10 @@ class EmbeddingRetriever:
69
  self._chunk_ids = [chunk.chunk_id for chunk in chunks]
70
  self._vectors = [_normalize(list(vector)) for vector in vectors]
71
 
72
- @zero_gpu_slot(duration=30)
73
  def query(self, text: str, *, top_k: int) -> list[str]:
74
  if not self._vectors:
75
  return []
76
- embedder = self._ensure_embedder()
77
- query_vec = embedder([text])
78
  if not query_vec:
79
  return []
80
  query_vector = _normalize(list(query_vec[0]))
@@ -88,6 +124,18 @@ class EmbeddingRetriever:
88
  scored.sort(key=lambda item: (-item[0], item[1]))
89
  return [self._chunk_ids[index] for _score, index in scored[:top_k]]
90
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def _ensure_embedder(self) -> Embedder:
92
  if self._embedder is not None:
93
  return self._embedder
 
12
  - Pass `embedder=...` directly (used by tests and any caller that wants
13
  full control over batching, device placement, or remote inference).
14
  - Pass `model_id=...` and let the retriever lazy-load
15
+ sentence-transformers via the stateless `_gpu_encode_batch` helper.
 
16
  - Index and query both call the embedder. The retriever is stateless
17
  beyond the indexed chunk vectors; reusing across documents requires a
18
  fresh `index()` call, same contract as LexicalRetriever.
 
21
  Other sentence-transformers models work as long as they accept the same
22
  encode signature; jina's task-prompt argument is optional and silently
23
  ignored by models that don't recognize it.
24
+
25
+ ZeroGPU note: the GPU slot decorator runs the wrapped function in a
26
+ separate worker process. Mutations to `self` made inside the worker do
27
+ NOT propagate back to the caller. So `index()` and `query()` are
28
+ intentionally NOT decorated — the GPU work is offloaded to the free
29
+ stateless `_gpu_encode_batch(model_id, task, texts) -> vectors` helper,
30
+ and the calling EmbeddingRetriever instance (which holds chunk_ids and
31
+ vectors) stays in the main process.
32
  """
33
 
34
  from __future__ import annotations
 
41
  Embedder = Callable[[list[str]], list[list[float]]]
42
 
43
 
44
+ @zero_gpu_slot(duration=180)
45
+ def _gpu_encode_batch(model_id: str, task: str | None, texts: list[str]) -> list[list[float]]:
46
+ """Load a sentence-transformers model and encode `texts` under a ZeroGPU slot.
47
+
48
+ Stateless by design: takes only picklable inputs (strings) and returns a
49
+ list-of-lists of floats. The model is loaded fresh inside the worker
50
+ process — that's where ZeroGPU has GPU access. Subsequent calls re-load
51
+ (acceptable for bursty workloads); for sustained-throughput workloads,
52
+ pin the Space to non-ZeroGPU hardware and inject an `embedder` callable
53
+ so the model stays warm in the main process.
54
+ """
55
+
56
+ try:
57
+ from sentence_transformers import SentenceTransformer # type: ignore
58
+ except ImportError as exc:
59
+ raise RuntimeError(
60
+ "EmbeddingRetriever requires sentence-transformers. "
61
+ "Install with `pip install sentence-transformers` or pass `embedder=...` explicitly."
62
+ ) from exc
63
+
64
+ model = SentenceTransformer(model_id, trust_remote_code=True)
65
+ kwargs: dict[str, Any] = {"normalize_embeddings": True}
66
+ if task:
67
+ try:
68
+ vectors = model.encode(texts, task=task, **kwargs)
69
+ except TypeError:
70
+ vectors = model.encode(texts, **kwargs)
71
+ else:
72
+ vectors = model.encode(texts, **kwargs)
73
+ return [list(map(float, vector)) for vector in vectors]
74
+
75
+
76
  class EmbeddingRetriever:
77
  def __init__(
78
  self,
 
90
  self._chunk_ids: list[str] = []
91
  self._vectors: list[list[float]] = []
92
 
 
93
  def index(self, chunks: Sequence[Chunk]) -> None:
94
+ # NOT decorated with @zero_gpu_slot see module docstring. The GPU
95
+ # work is offloaded to the stateless _gpu_encode_batch helper so
96
+ # mutations to self stay in the main process.
97
  texts = [chunk.text for chunk in chunks]
98
  if not texts:
99
  self._chunk_ids = []
100
  self._vectors = []
101
  return
102
+ vectors = self._encode(texts, task=self._task)
103
  if len(vectors) != len(texts):
104
  raise RuntimeError(
105
  f"EmbeddingRetriever embedder returned {len(vectors)} vectors for {len(texts)} chunks."
 
107
  self._chunk_ids = [chunk.chunk_id for chunk in chunks]
108
  self._vectors = [_normalize(list(vector)) for vector in vectors]
109
 
 
110
  def query(self, text: str, *, top_k: int) -> list[str]:
111
  if not self._vectors:
112
  return []
113
+ query_vec = self._encode([text], task=self._query_task)
 
114
  if not query_vec:
115
  return []
116
  query_vector = _normalize(list(query_vec[0]))
 
124
  scored.sort(key=lambda item: (-item[0], item[1]))
125
  return [self._chunk_ids[index] for _score, index in scored[:top_k]]
126
 
127
+ def _encode(self, texts: list[str], *, task: str | None) -> list[list[float]]:
128
+ """Dispatch encode to the injected embedder or the GPU helper.
129
+
130
+ Test path: `embedder=...` was passed to __init__, runs in-process.
131
+ Production path: model_id was passed (default jina-v3), runs inside
132
+ the @spaces.GPU-decorated worker via _gpu_encode_batch.
133
+ """
134
+
135
+ if self._explicit_embedder is not None:
136
+ return self._explicit_embedder(texts)
137
+ return _gpu_encode_batch(self._model_id, task, texts)
138
+
139
  def _ensure_embedder(self) -> Embedder:
140
  if self._embedder is not None:
141
  return self._embedder