KinetoLabs Claude Opus 4.5 commited on
Commit
5f0db1e
·
1 Parent(s): c190082

Implement lazy model loading to prevent CUDA OOM on 4xL4 GPUs

Browse files

Problem: All 3 models (~92GB) loaded at startup exceeded 88GB VRAM.

Solution: Sequential loading - vision model during Stage 2, RAG models
during Stage 3+. Vision is unloaded before RAG loads. Peak: ~60GB.

Changes:
- models/real.py: Add load_vision(), unload_vision(), load_rag() with
proper hook removal per HuggingFace accelerate docs
- models/loader.py: Real models now use lazy loading (no load_all)
- pipeline/main.py: Load/unload at appropriate pipeline stages
- rag/vectorstore.py: Use SharedEmbeddingFunction (no duplicate load)
- rag/retriever.py: Use SharedReranker (no duplicate load)
- models/mock.py: Add is_vision_loaded(), is_rag_loaded() for API parity

Memory profile:
- Phase A (Vision): 30B model ~60GB
- Transition: Unload + gc + empty_cache
- Phase B (RAG): 8B + 8B ~32GB
- Peak never exceeds 60GB (fits in 88GB)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (6) hide show
  1. models/loader.py +26 -6
  2. models/mock.py +14 -1
  3. models/real.py +134 -19
  4. pipeline/main.py +15 -0
  5. rag/retriever.py +21 -55
  6. rag/vectorstore.py +21 -79
models/loader.py CHANGED
@@ -1,4 +1,14 @@
1
- """Model loading with mock/real switching based on environment."""
 
 
 
 
 
 
 
 
 
 
2
 
3
  import logging
4
  import time
@@ -16,7 +26,11 @@ _model_stack: ModelStack | None = None
16
 
17
 
18
  def get_model_stack() -> ModelStack:
19
- """Get model stack based on environment configuration."""
 
 
 
 
20
  start_time = time.time()
21
 
22
  if settings.mock_models:
@@ -28,20 +42,26 @@ def get_model_stack() -> ModelStack:
28
  logger.info(f"Mock model stack loaded in {elapsed:.2f}s")
29
  return stack
30
  else:
31
- logger.info("Loading REAL model stack (production mode)")
32
  logger.info(f"Vision model: {settings.vision_model}")
33
  logger.info(f"Embedding model: {settings.embedding_model}")
34
  logger.info(f"Reranker model: {settings.reranker_model}")
 
35
  from models.real import RealModelStack
36
 
37
- stack = RealModelStack().load_all()
 
38
  elapsed = time.time() - start_time
39
- logger.info(f"Real model stack loaded in {elapsed:.2f}s")
40
  return stack
41
 
42
 
43
  def get_models() -> ModelStack:
44
- """Get or create the singleton model stack."""
 
 
 
 
45
  global _model_stack
46
  if _model_stack is None:
47
  logger.debug("Model stack not initialized, creating new stack")
 
1
+ """Model loading with mock/real switching based on environment.
2
+
3
+ Supports two loading modes:
4
+ - MOCK_MODELS=true: Loads all mock models at startup (fast, for local dev)
5
+ - MOCK_MODELS=false: Uses LAZY LOADING (models loaded on-demand by pipeline)
6
+
7
+ Lazy Loading Strategy (for 4xL4 GPUs with 88GB total):
8
+ - Vision 30B (~60GB) loaded before Stage 2, unloaded after
9
+ - RAG models (~32GB) loaded before Stage 3
10
+ - Peak usage ~60GB, never both simultaneously
11
+ """
12
 
13
  import logging
14
  import time
 
26
 
27
 
28
  def get_model_stack() -> ModelStack:
29
+ """Get model stack based on environment configuration.
30
+
31
+ For mock models: Loads all models immediately (fast, for local dev).
32
+ For real models: Returns uninitialized stack for lazy loading.
33
+ """
34
  start_time = time.time()
35
 
36
  if settings.mock_models:
 
42
  logger.info(f"Mock model stack loaded in {elapsed:.2f}s")
43
  return stack
44
  else:
45
+ logger.info("Creating REAL model stack (production mode - lazy loading)")
46
  logger.info(f"Vision model: {settings.vision_model}")
47
  logger.info(f"Embedding model: {settings.embedding_model}")
48
  logger.info(f"Reranker model: {settings.reranker_model}")
49
+ logger.info("NOTE: Models will be loaded on-demand by pipeline stages")
50
  from models.real import RealModelStack
51
 
52
+ # Don't load models yet - pipeline will call load_vision() and load_rag()
53
+ stack = RealModelStack()
54
  elapsed = time.time() - start_time
55
+ logger.info(f"Real model stack initialized in {elapsed:.2f}s (no models loaded yet)")
56
  return stack
57
 
58
 
59
  def get_models() -> ModelStack:
60
+ """Get or create the singleton model stack.
61
+
62
+ For real models, this returns an uninitialized stack.
63
+ Call stack.load_vision() or stack.load_rag() as needed.
64
+ """
65
  global _model_stack
66
  if _model_stack is None:
67
  logger.debug("Model stack not initialized, creating new stack")
models/mock.py CHANGED
@@ -186,7 +186,12 @@ class MockRerankerModel:
186
 
187
 
188
  class MockModelStack:
189
- """Mock model stack for local development."""
 
 
 
 
 
190
 
191
  def __init__(self):
192
  self.vision = MockVisionModel()
@@ -207,3 +212,11 @@ class MockModelStack:
207
  def is_loaded(self) -> bool:
208
  """Check if models are loaded."""
209
  return self.loaded
 
 
 
 
 
 
 
 
 
186
 
187
 
188
  class MockModelStack:
189
+ """Mock model stack for local development.
190
+
191
+ Unlike RealModelStack, mock models are always loaded together.
192
+ The is_vision_loaded() and is_rag_loaded() methods are provided
193
+ for API compatibility with the lazy loading pipeline.
194
+ """
195
 
196
  def __init__(self):
197
  self.vision = MockVisionModel()
 
212
  def is_loaded(self) -> bool:
213
  """Check if models are loaded."""
214
  return self.loaded
215
+
216
+ def is_vision_loaded(self) -> bool:
217
+ """Check if vision model is loaded (always True when loaded)."""
218
+ return self.loaded
219
+
220
+ def is_rag_loaded(self) -> bool:
221
+ """Check if RAG models are loaded (always True when loaded)."""
222
+ return self.loaded
models/real.py CHANGED
@@ -1,7 +1,13 @@
1
  """Real model loading for production (HuggingFace Spaces with 4xL4 GPUs).
2
 
3
  This module loads the actual Qwen3-VL models for production use.
4
- Requires ~90GB VRAM (4xL4 with 96GB total).
 
 
 
 
 
 
5
 
6
  Model Loading:
7
  - Vision: Qwen3VLMoeForConditionalGeneration (standard transformers)
@@ -9,6 +15,7 @@ Model Loading:
9
  - Reranker: Qwen3VLReranker (official scripts from QwenLM/Qwen3-VL-Embedding)
10
  """
11
 
 
12
  import json
13
  import logging
14
  import re
@@ -24,27 +31,48 @@ logger = logging.getLogger(__name__)
24
 
25
 
26
  class RealModelStack:
27
- """Real model stack for production on HuggingFace Spaces."""
 
 
 
 
 
 
28
 
29
  def __init__(self):
30
  self.models: dict[str, Any] = {}
31
  self.processors: dict[str, Any] = {}
32
- self.loaded = False
33
-
34
- def load_all(self) -> "RealModelStack":
35
- """Load all models with device_map='auto' for multi-GPU distribution."""
36
- from transformers import AutoProcessor
37
 
38
- device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
39
- logger.info(f"Loading models on {device_type}")
40
  if torch.cuda.is_available():
41
  gpu_count = torch.cuda.device_count()
42
- logger.info(f"CUDA devices available: {gpu_count}")
43
  for i in range(gpu_count):
44
- mem_gb = torch.cuda.get_device_properties(i).total_memory / (1024**3)
45
- logger.info(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem_gb:.1f} GB)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Vision model (~58GB in BF16)
48
  logger.info(f"Loading vision model: {settings.vision_model}")
49
  vision_start = time.time()
50
  try:
@@ -64,6 +92,8 @@ class RealModelStack:
64
  except Exception as e:
65
  logger.warning(f"Failed to load 30B vision model: {e}")
66
  logger.info(f"Falling back to {settings.vision_model_fallback}")
 
 
67
  self.models["vision"] = Qwen3VLMoeForConditionalGeneration.from_pretrained(
68
  settings.vision_model_fallback,
69
  torch_dtype=torch.bfloat16,
@@ -76,6 +106,66 @@ class RealModelStack:
76
  )
77
  logger.info(f"Fallback vision model loaded in {time.time() - vision_start:.2f}s")
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # Embedding model (~16GB in BF16) - Using official Qwen3VLEmbedder
80
  logger.info(f"Loading embedding model: {settings.embedding_model}")
81
  embed_start = time.time()
@@ -85,7 +175,6 @@ class RealModelStack:
85
  model_name_or_path=settings.embedding_model,
86
  torch_dtype=torch.bfloat16,
87
  )
88
- # Processor is internal to Qwen3VLEmbedder, but store reference for compatibility
89
  self.processors["embedding"] = self.models["embedding"].processor
90
  logger.info(f"Embedding model loaded in {time.time() - embed_start:.2f}s")
91
 
@@ -98,31 +187,57 @@ class RealModelStack:
98
  model_name_or_path=settings.reranker_model,
99
  torch_dtype=torch.bfloat16,
100
  )
101
- # Processor is internal to Qwen3VLReranker, but store reference for compatibility
102
  self.processors["reranker"] = self.models["reranker"].processor
103
  logger.info(f"Reranker model loaded in {time.time() - reranker_start:.2f}s")
104
 
105
- self.loaded = True
106
- logger.info("All models loaded successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  return self
108
 
109
  def is_loaded(self) -> bool:
110
- """Check if models are loaded."""
111
- return self.loaded
 
 
 
 
 
 
 
 
112
 
113
  @property
114
  def vision(self) -> "RealVisionModel":
115
  """Return vision model wrapped for pipeline consumption."""
 
 
116
  return RealVisionModel(self.models["vision"], self.processors["vision"])
117
 
118
  @property
119
  def embedding(self) -> "RealEmbeddingModel":
120
  """Return embedding model wrapped for pipeline consumption."""
 
 
121
  return RealEmbeddingModel(self.models["embedding"], self.processors["embedding"])
122
 
123
  @property
124
  def reranker(self) -> "RealRerankerModel":
125
  """Return reranker model wrapped for pipeline consumption."""
 
 
126
  return RealRerankerModel(self.models["reranker"], self.processors["reranker"])
127
 
128
 
 
1
  """Real model loading for production (HuggingFace Spaces with 4xL4 GPUs).
2
 
3
  This module loads the actual Qwen3-VL models for production use.
4
+ Uses LAZY LOADING to fit within 88GB VRAM (4xL4 with ~22GB each).
5
+
6
+ Memory Strategy:
7
+ - Vision 30B (~60GB): Loaded ONLY during Stage 2 (Vision Analysis)
8
+ - Embedding 8B (~16GB): Loaded ONLY during Stages 3+ (RAG)
9
+ - Reranker 8B (~16GB): Loaded ONLY during Stages 3+ (RAG)
10
+ - Peak usage: ~60GB (never all three simultaneously)
11
 
12
  Model Loading:
13
  - Vision: Qwen3VLMoeForConditionalGeneration (standard transformers)
 
15
  - Reranker: Qwen3VLReranker (official scripts from QwenLM/Qwen3-VL-Embedding)
16
  """
17
 
18
+ import gc
19
  import json
20
  import logging
21
  import re
 
31
 
32
 
33
  class RealModelStack:
34
+ """Real model stack for production on HuggingFace Spaces.
35
+
36
+ Uses LAZY LOADING to prevent OOM errors on 4xL4 (88GB total):
37
+ - Vision 30B (~60GB) and RAG models (~32GB) are never loaded simultaneously
38
+ - Pipeline calls load_vision() before Stage 2, unload_vision() after
39
+ - Pipeline calls load_rag() before Stage 3
40
+ """
41
 
42
  def __init__(self):
43
  self.models: dict[str, Any] = {}
44
  self.processors: dict[str, Any] = {}
45
+ self._vision_loaded = False
46
+ self._rag_loaded = False
 
 
 
47
 
48
+ def _log_gpu_status(self):
49
+ """Log current GPU memory status."""
50
  if torch.cuda.is_available():
51
  gpu_count = torch.cuda.device_count()
52
+ logger.info(f"GPU memory status ({gpu_count} devices):")
53
  for i in range(gpu_count):
54
+ total = torch.cuda.get_device_properties(i).total_memory / (1024**3)
55
+ allocated = torch.cuda.memory_allocated(i) / (1024**3)
56
+ cached = torch.cuda.memory_reserved(i) / (1024**3)
57
+ free = total - allocated
58
+ logger.info(f" GPU {i}: {allocated:.1f}GB allocated, {cached:.1f}GB cached, {free:.1f}GB free / {total:.1f}GB total")
59
+
60
+ def load_vision(self) -> "RealModelStack":
61
+ """Load only the vision model (~60GB in BF16).
62
+
63
+ Call this before Stage 2 (Vision Analysis).
64
+ Must call unload_vision() before load_rag() to free memory.
65
+ """
66
+ if self._vision_loaded:
67
+ logger.debug("Vision model already loaded, skipping")
68
+ return self
69
+
70
+ from transformers import AutoProcessor
71
+
72
+ device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
73
+ logger.info(f"Loading vision model on {device_type}")
74
+ self._log_gpu_status()
75
 
 
76
  logger.info(f"Loading vision model: {settings.vision_model}")
77
  vision_start = time.time()
78
  try:
 
92
  except Exception as e:
93
  logger.warning(f"Failed to load 30B vision model: {e}")
94
  logger.info(f"Falling back to {settings.vision_model_fallback}")
95
+ from transformers import Qwen3VLMoeForConditionalGeneration
96
+
97
  self.models["vision"] = Qwen3VLMoeForConditionalGeneration.from_pretrained(
98
  settings.vision_model_fallback,
99
  torch_dtype=torch.bfloat16,
 
106
  )
107
  logger.info(f"Fallback vision model loaded in {time.time() - vision_start:.2f}s")
108
 
109
+ self._vision_loaded = True
110
+ self._log_gpu_status()
111
+ return self
112
+
113
+ def unload_vision(self):
114
+ """Unload vision model and free CUDA memory.
115
+
116
+ Uses accelerate's remove_hook_from_module per HuggingFace docs.
117
+ Call this after Stage 2 (Vision Analysis) to free memory for RAG.
118
+ """
119
+ if not self._vision_loaded or "vision" not in self.models:
120
+ logger.debug("Vision model not loaded, skipping unload")
121
+ return
122
+
123
+ logger.info("Unloading vision model to free memory for RAG...")
124
+ self._log_gpu_status()
125
+
126
+ try:
127
+ from accelerate.hooks import remove_hook_from_module
128
+
129
+ # CRITICAL: Remove hooks before deleting (required for device_map="auto")
130
+ model = self.models["vision"]
131
+ if hasattr(model, 'model'):
132
+ # Some wrappers have nested model
133
+ remove_hook_from_module(model.model, recurse=True)
134
+ remove_hook_from_module(model, recurse=True)
135
+ logger.debug("Accelerate hooks removed from vision model")
136
+ except ImportError:
137
+ logger.warning("accelerate.hooks not available, proceeding with basic cleanup")
138
+ except Exception as e:
139
+ logger.warning(f"Hook removal failed (continuing anyway): {e}")
140
+
141
+ # Delete model and processor
142
+ del self.models["vision"]
143
+ del self.processors["vision"]
144
+ self._vision_loaded = False
145
+
146
+ # Clear CUDA cache (may not free 100% but sufficient for sequential loading)
147
+ gc.collect()
148
+ torch.cuda.empty_cache()
149
+
150
+ logger.info("Vision model unloaded, CUDA cache cleared")
151
+ self._log_gpu_status()
152
+
153
+ def load_rag(self) -> "RealModelStack":
154
+ """Load embedding and reranker models (~32GB total in BF16).
155
+
156
+ Call this before Stage 3 (RAG Retrieval).
157
+ Must call unload_vision() first to have enough memory.
158
+ """
159
+ if self._rag_loaded:
160
+ logger.debug("RAG models already loaded, skipping")
161
+ return self
162
+
163
+ if self._vision_loaded:
164
+ logger.warning("Vision model still loaded! Call unload_vision() first to avoid OOM.")
165
+
166
+ logger.info("Loading RAG models (embedding + reranker)...")
167
+ self._log_gpu_status()
168
+
169
  # Embedding model (~16GB in BF16) - Using official Qwen3VLEmbedder
170
  logger.info(f"Loading embedding model: {settings.embedding_model}")
171
  embed_start = time.time()
 
175
  model_name_or_path=settings.embedding_model,
176
  torch_dtype=torch.bfloat16,
177
  )
 
178
  self.processors["embedding"] = self.models["embedding"].processor
179
  logger.info(f"Embedding model loaded in {time.time() - embed_start:.2f}s")
180
 
 
187
  model_name_or_path=settings.reranker_model,
188
  torch_dtype=torch.bfloat16,
189
  )
 
190
  self.processors["reranker"] = self.models["reranker"].processor
191
  logger.info(f"Reranker model loaded in {time.time() - reranker_start:.2f}s")
192
 
193
+ self._rag_loaded = True
194
+ logger.info("RAG models loaded successfully")
195
+ self._log_gpu_status()
196
+ return self
197
+
198
+ def load_all(self) -> "RealModelStack":
199
+ """Load all models (DEPRECATED - use lazy loading instead).
200
+
201
+ This method is kept for backward compatibility but will cause OOM
202
+ on 4xL4 GPUs. Use load_vision() and load_rag() sequentially instead.
203
+ """
204
+ logger.warning("load_all() is deprecated - use load_vision() and load_rag() for lazy loading")
205
+ self.load_vision()
206
+ # Note: This WILL cause OOM on 4xL4 as vision (60GB) + RAG (32GB) > 88GB
207
+ self.load_rag()
208
  return self
209
 
210
  def is_loaded(self) -> bool:
211
+ """Check if any models are loaded."""
212
+ return self._vision_loaded or self._rag_loaded
213
+
214
+ def is_vision_loaded(self) -> bool:
215
+ """Check if vision model is loaded."""
216
+ return self._vision_loaded
217
+
218
+ def is_rag_loaded(self) -> bool:
219
+ """Check if RAG models are loaded."""
220
+ return self._rag_loaded
221
 
222
  @property
223
  def vision(self) -> "RealVisionModel":
224
  """Return vision model wrapped for pipeline consumption."""
225
+ if not self._vision_loaded:
226
+ raise RuntimeError("Vision model not loaded. Call load_vision() first.")
227
  return RealVisionModel(self.models["vision"], self.processors["vision"])
228
 
229
  @property
230
  def embedding(self) -> "RealEmbeddingModel":
231
  """Return embedding model wrapped for pipeline consumption."""
232
+ if not self._rag_loaded:
233
+ raise RuntimeError("Embedding model not loaded. Call load_rag() first.")
234
  return RealEmbeddingModel(self.models["embedding"], self.processors["embedding"])
235
 
236
  @property
237
  def reranker(self) -> "RealRerankerModel":
238
  """Return reranker model wrapped for pipeline consumption."""
239
+ if not self._rag_loaded:
240
+ raise RuntimeError("Reranker model not loaded. Call load_rag() first.")
241
  return RealRerankerModel(self.models["reranker"], self.processors["reranker"])
242
 
243
 
pipeline/main.py CHANGED
@@ -199,6 +199,11 @@ class FDAMPipeline:
199
  logger.info(f"Stage 2/6: Vision Analysis ({len(session.images)} images)")
200
  report_progress(2, "Analyzing images with AI...")
201
  model_stack = get_models()
 
 
 
 
 
202
  vision_results = {}
203
  annotated_images = []
204
  room_mapping = {}
@@ -259,10 +264,20 @@ class FDAMPipeline:
259
  logger.info(f"Stage 2 completed in {time.time() - stage_start:.2f}s: "
260
  f"{len(vision_results)} images analyzed")
261
 
 
 
 
 
 
262
  # Stage 3: RAG Retrieval
263
  stage_start = time.time()
264
  logger.info("Stage 3/6: RAG Retrieval")
265
  report_progress(3, "Retrieving FDAM methodology context...")
 
 
 
 
 
266
  # RAG is integrated into disposition engine, just verify connection
267
  try:
268
  test_results = self.retriever.retrieve("test connection", top_k=1)
 
199
  logger.info(f"Stage 2/6: Vision Analysis ({len(session.images)} images)")
200
  report_progress(2, "Analyzing images with AI...")
201
  model_stack = get_models()
202
+
203
+ # Lazy load vision model (for real models only - mock models are already loaded)
204
+ if hasattr(model_stack, 'load_vision') and not model_stack.is_vision_loaded():
205
+ logger.info("Lazy loading vision model...")
206
+ model_stack.load_vision()
207
  vision_results = {}
208
  annotated_images = []
209
  room_mapping = {}
 
264
  logger.info(f"Stage 2 completed in {time.time() - stage_start:.2f}s: "
265
  f"{len(vision_results)} images analyzed")
266
 
267
+ # Unload vision model to free memory for RAG (for real models only)
268
+ if hasattr(model_stack, 'unload_vision') and model_stack.is_vision_loaded():
269
+ logger.info("Unloading vision model to free memory for RAG...")
270
+ model_stack.unload_vision()
271
+
272
  # Stage 3: RAG Retrieval
273
  stage_start = time.time()
274
  logger.info("Stage 3/6: RAG Retrieval")
275
  report_progress(3, "Retrieving FDAM methodology context...")
276
+
277
+ # Lazy load RAG models (for real models only - mock models are already loaded)
278
+ if hasattr(model_stack, 'load_rag') and not model_stack.is_rag_loaded():
279
+ logger.info("Lazy loading RAG models (embedding + reranker)...")
280
+ model_stack.load_rag()
281
  # RAG is integrated into disposition engine, just verify connection
282
  try:
283
  test_results = self.retriever.retrieve("test connection", top_k=1)
rag/retriever.py CHANGED
@@ -84,84 +84,50 @@ class MockReranker:
84
  return scores
85
 
86
 
87
- class RealReranker:
88
- """Real reranker using Qwen3-VL-Reranker-8B.
89
 
90
- Loaded on-demand when MOCK_MODELS=false.
 
91
  """
92
 
93
- def __init__(self):
94
- self.model = None
95
- self.tokenizer = None
96
-
97
- def _load_model(self):
98
- """Lazy load the reranker model."""
99
- if self.model is not None:
100
- return
101
-
102
- import torch
103
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
104
-
105
- model_name = "Qwen/Qwen3-VL-Reranker-8B"
106
- logger.info(f"Loading reranker model: {model_name}")
107
-
108
- self.tokenizer = AutoTokenizer.from_pretrained(
109
- model_name,
110
- trust_remote_code=True,
111
- )
112
- self.model = AutoModelForSequenceClassification.from_pretrained(
113
- model_name,
114
- torch_dtype=torch.bfloat16,
115
- device_map="auto",
116
- trust_remote_code=True,
117
- )
118
- self.model.eval()
119
-
120
  def rerank(
121
  self,
122
  query: str,
123
  documents: list[str],
124
  ) -> list[float]:
125
- """Score documents using the reranker model.
126
 
127
  Args:
128
  query: Query text
129
  documents: List of document texts
130
 
131
  Returns:
132
- List of scores for each document
133
  """
134
- self._load_model()
135
 
136
- import torch
137
 
138
- scores = []
139
- with torch.no_grad():
140
- for doc in documents:
141
- inputs = self.tokenizer(
142
- query,
143
- doc,
144
- return_tensors="pt",
145
- truncation=True,
146
- max_length=512,
147
- padding=True,
148
- )
149
- # Note: With device_map="auto", transformers handles device routing internally
150
- # Do NOT call .to(device) - it breaks distributed models
151
 
152
- outputs = self.model(**inputs)
153
- # Sigmoid to get 0-1 score
154
- score = torch.sigmoid(outputs.logits).squeeze().item()
155
- scores.append(score)
156
-
157
- return scores
158
 
159
 
160
  def get_reranker():
161
- """Get appropriate reranker based on settings."""
 
 
 
 
162
  if settings.mock_models:
163
  return MockReranker()
164
- return RealReranker()
165
 
166
 
167
  class FDAMRetriever:
 
84
  return scores
85
 
86
 
87
+ class SharedReranker:
88
+ """Reranker that uses the shared model from RealModelStack.
89
 
90
+ This avoids loading a duplicate reranker model - instead uses the
91
+ model already loaded by the pipeline via model_stack.load_rag().
92
  """
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def rerank(
95
  self,
96
  query: str,
97
  documents: list[str],
98
  ) -> list[float]:
99
+ """Score documents using the shared reranker model.
100
 
101
  Args:
102
  query: Query text
103
  documents: List of document texts
104
 
105
  Returns:
106
+ List of scores (0-1) for each document
107
  """
108
+ from models.loader import get_models
109
 
110
+ model_stack = get_models()
111
 
112
+ # Check if RAG models are loaded
113
+ if not model_stack.is_rag_loaded():
114
+ logger.warning("RAG models not loaded yet - reranking may fail")
115
+ # Return neutral scores as fallback
116
+ return [0.5] * len(documents)
 
 
 
 
 
 
 
 
117
 
118
+ # Use the shared reranker model
119
+ return model_stack.reranker.rerank(query, documents)
 
 
 
 
120
 
121
 
122
  def get_reranker():
123
+ """Get appropriate reranker based on settings.
124
+
125
+ For real models, uses SharedReranker which wraps the
126
+ model stack's reranker model (no duplicate loading).
127
+ """
128
  if settings.mock_models:
129
  return MockReranker()
130
+ return SharedReranker()
131
 
132
 
133
  class FDAMRetriever:
rag/vectorstore.py CHANGED
@@ -58,100 +58,42 @@ class MockEmbeddingFunction:
58
  return embedding
59
 
60
 
61
- class RealEmbeddingFunction:
62
- """Real embedding function using Qwen3-VL-Embedding-8B.
63
 
64
- Uses last-token pooling per official Qwen3-VL-Embedding implementation.
65
- Loaded on-demand when MOCK_MODELS=false.
66
 
67
- Reference: https://github.com/QwenLM/Qwen3-VL-Embedding
68
  """
69
 
70
  EMBEDDING_DIM = 4096 # Per Qwen3-VL-Embedding-8B hidden_size
71
 
72
- def __init__(self):
73
- self.model = None
74
- self.tokenizer = None
75
-
76
- def _load_model(self):
77
- """Lazy load the embedding model."""
78
- if self.model is not None:
79
- return
80
-
81
- import torch
82
- from transformers import AutoModel, AutoTokenizer
83
-
84
- model_name = "Qwen/Qwen3-VL-Embedding-8B"
85
- logger.info(f"Loading embedding model: {model_name}")
86
-
87
- self.tokenizer = AutoTokenizer.from_pretrained(
88
- model_name,
89
- trust_remote_code=True,
90
- )
91
- self.model = AutoModel.from_pretrained(
92
- model_name,
93
- torch_dtype=torch.bfloat16,
94
- device_map="auto",
95
- trust_remote_code=True,
96
- )
97
- self.model.eval()
98
-
99
- @staticmethod
100
- def _pooling_last(hidden_state, attention_mask):
101
- """Extract the last valid token's hidden state.
102
-
103
- Official pooling method from Qwen3-VL-Embedding.
104
- Finds the last position where attention_mask == 1 and extracts that token.
105
- """
106
- import torch
107
-
108
- flipped_tensor = attention_mask.flip(dims=[1])
109
- last_one_positions = flipped_tensor.argmax(dim=1)
110
- col = attention_mask.shape[1] - last_one_positions - 1
111
- row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
112
- return hidden_state[row, col]
113
-
114
  def __call__(self, input: list[str]) -> list[list[float]]:
115
- """Generate embeddings for a list of texts using last-token pooling."""
116
- self._load_model()
117
-
118
- import torch
119
-
120
- embeddings = []
121
- with torch.no_grad():
122
- for text in input:
123
- inputs = self.tokenizer(
124
- text,
125
- return_tensors="pt",
126
- truncation=True,
127
- max_length=512,
128
- padding=True,
129
- )
130
- # Note: With device_map="auto", transformers handles device routing internally
131
- # Do NOT call .to(device) - it breaks distributed models
132
 
133
- outputs = self.model(**inputs)
134
 
135
- # Use last-token pooling (official Qwen3-VL-Embedding method)
136
- attention_mask = inputs.get("attention_mask")
137
- if attention_mask is not None:
138
- embedding = self._pooling_last(outputs.last_hidden_state, attention_mask)
139
- else:
140
- # Fallback: use last token if no attention mask
141
- embedding = outputs.last_hidden_state[:, -1, :]
142
 
143
- # L2 normalize (per official implementation)
144
- embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
145
- embeddings.append(embedding.squeeze().cpu().float().tolist())
146
-
147
- return embeddings
148
 
149
 
150
  def get_embedding_function():
151
- """Get appropriate embedding function based on settings."""
 
 
 
 
152
  if settings.mock_models:
153
  return MockEmbeddingFunction()
154
- return RealEmbeddingFunction()
155
 
156
 
157
  class ChromaVectorStore:
 
58
  return embedding
59
 
60
 
61
+ class SharedEmbeddingFunction:
62
+ """Embedding function that uses the shared model from RealModelStack.
63
 
64
+ This avoids loading a duplicate embedding model - instead uses the
65
+ model already loaded by the pipeline via model_stack.load_rag().
66
 
67
+ For ChromaDB compatibility, this wraps the model stack's embedding model.
68
  """
69
 
70
  EMBEDDING_DIM = 4096 # Per Qwen3-VL-Embedding-8B hidden_size
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def __call__(self, input: list[str]) -> list[list[float]]:
73
+ """Generate embeddings using the shared model from model stack."""
74
+ from models.loader import get_models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ model_stack = get_models()
77
 
78
+ # Check if RAG models are loaded
79
+ if not model_stack.is_rag_loaded():
80
+ logger.warning("RAG models not loaded yet - embeddings may fail")
81
+ # Return zero vectors as fallback
82
+ return [[0.0] * self.EMBEDDING_DIM for _ in input]
 
 
83
 
84
+ # Use the shared embedding model
85
+ return model_stack.embedding.embed_batch(input)
 
 
 
86
 
87
 
88
  def get_embedding_function():
89
+ """Get appropriate embedding function based on settings.
90
+
91
+ For real models, uses SharedEmbeddingFunction which wraps the
92
+ model stack's embedding model (no duplicate loading).
93
+ """
94
  if settings.mock_models:
95
  return MockEmbeddingFunction()
96
+ return SharedEmbeddingFunction()
97
 
98
 
99
  class ChromaVectorStore: