Implement lazy model loading to prevent CUDA OOM on 4xL4 GPUs
Problem: All 3 models (~92GB) loaded at startup exceeded 88GB VRAM.
Solution: Sequential loading - vision model during Stage 2, RAG models
during Stage 3+. Vision is unloaded before RAG loads. Peak: ~60GB.
Changes:
- models/real.py: Add load_vision(), unload_vision(), load_rag() with
proper hook removal per HuggingFace accelerate docs
- models/loader.py: Real models now use lazy loading (no load_all)
- pipeline/main.py: Load/unload at appropriate pipeline stages
- rag/vectorstore.py: Use SharedEmbeddingFunction (no duplicate load)
- rag/retriever.py: Use SharedReranker (no duplicate load)
- models/mock.py: Add is_vision_loaded(), is_rag_loaded() for API parity
Memory profile:
- Phase A (Vision): 30B model ~60GB
- Transition: Unload + gc + empty_cache
- Phase B (RAG): 8B + 8B ~32GB
- Peak never exceeds 60GB (fits in 88GB)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- models/loader.py +26 -6
- models/mock.py +14 -1
- models/real.py +134 -19
- pipeline/main.py +15 -0
- rag/retriever.py +21 -55
- rag/vectorstore.py +21 -79
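
Taken together, the loading lifecycle looks like this. A minimal sketch using the methods added in this commit; the driver function and the `generate()` call are illustrative, not code from the repo:

from models.real import RealModelStack

def run_pipeline_sketch(images):
    stack = RealModelStack()    # no weights loaded at construction

    stack.load_vision()         # Phase A: 30B vision model, ~60GB
    results = [stack.vision.generate(img) for img in images]  # illustrative call
    stack.unload_vision()       # hooks removed, then gc.collect() + empty_cache()

    stack.load_rag()            # Phase B: 8B + 8B RAG models, ~32GB
    return results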
models/loader.py
@@ -1,4 +1,14 @@
-"""Model loading with mock/real switching based on environment."""
+"""Model loading with mock/real switching based on environment.
+
+Supports two loading modes:
+- MOCK_MODELS=true: Loads all mock models at startup (fast, for local dev)
+- MOCK_MODELS=false: Uses LAZY LOADING (models loaded on-demand by pipeline)
+
+Lazy Loading Strategy (for 4xL4 GPUs with 88GB total):
+- Vision 30B (~60GB) loaded before Stage 2, unloaded after
+- RAG models (~32GB) loaded before Stage 3
+- Peak usage ~60GB, never both simultaneously
+"""
 
 import logging
 import time
@@ -16,7 +26,11 @@ _model_stack: ModelStack | None = None
 
 
 def get_model_stack() -> ModelStack:
-    """Get model stack based on environment configuration."""
+    """Get model stack based on environment configuration.
+
+    For mock models: Loads all models immediately (fast, for local dev).
+    For real models: Returns uninitialized stack for lazy loading.
+    """
     start_time = time.time()
 
     if settings.mock_models:
@@ -28,20 +42,26 @@ def get_model_stack() -> ModelStack:
         logger.info(f"Mock model stack loaded in {elapsed:.2f}s")
         return stack
     else:
-        logger.info("Creating REAL model stack (production mode)")
+        logger.info("Creating REAL model stack (production mode - lazy loading)")
         logger.info(f"Vision model: {settings.vision_model}")
         logger.info(f"Embedding model: {settings.embedding_model}")
         logger.info(f"Reranker model: {settings.reranker_model}")
+        logger.info("NOTE: Models will be loaded on-demand by pipeline stages")
         from models.real import RealModelStack
 
-        stack = RealModelStack().load_all()
+        # Don't load models yet - pipeline will call load_vision() and load_rag()
+        stack = RealModelStack()
         elapsed = time.time() - start_time
-        logger.info(f"Real model stack loaded in {elapsed:.2f}s")
+        logger.info(f"Real model stack initialized in {elapsed:.2f}s (no models loaded yet)")
         return stack
 
 
 def get_models() -> ModelStack:
-    """Get or create the singleton model stack."""
+    """Get or create the singleton model stack.
+
+    For real models, this returns an uninitialized stack.
+    Call stack.load_vision() or stack.load_rag() as needed.
+    """
     global _model_stack
     if _model_stack is None:
         logger.debug("Model stack not initialized, creating new stack")
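As a usage note, the mode switch is driven purely by configuration. A sketch, assuming `settings.mock_models` is read from the `MOCK_MODELS` environment variable as the module docstring states:

import os

os.environ["MOCK_MODELS"] = "false"   # must be set before settings is imported

from models.loader import get_models

stack = get_models()          # real mode: returns an uninitialized RealModelStack
assert not stack.is_loaded()  # no weights on the GPUs until a stage requests them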
models/mock.py
@@ -186,7 +186,12 @@ class MockRerankerModel:
 
 
 class MockModelStack:
-    """Mock model stack for local development."""
+    """Mock model stack for local development.
+
+    Unlike RealModelStack, mock models are always loaded together.
+    The is_vision_loaded() and is_rag_loaded() methods are provided
+    for API compatibility with the lazy loading pipeline.
+    """
 
     def __init__(self):
         self.vision = MockVisionModel()
@@ -207,3 +212,11 @@ class MockModelStack:
     def is_loaded(self) -> bool:
         """Check if models are loaded."""
         return self.loaded
+
+    def is_vision_loaded(self) -> bool:
+        """Check if vision model is loaded (always True when loaded)."""
+        return self.loaded
+
+    def is_rag_loaded(self) -> bool:
+        """Check if RAG models are loaded (always True when loaded)."""
+        return self.loaded
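These parity methods matter because the pipeline feature-detects lazy loading with `hasattr()` rather than `isinstance()`. A sketch of the guard both stacks satisfy (guard shape taken from the pipeline/main.py diff below):

def ensure_vision_loaded(model_stack) -> None:
    # RealModelStack defines load_vision(); MockModelStack does not, so the
    # hasattr() check short-circuits and the always-loaded mocks pass through.
    if hasattr(model_stack, "load_vision") and not model_stack.is_vision_loaded():
        model_stack.load_vision()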
models/real.py
@@ -1,7 +1,13 @@
 """Real model loading for production (HuggingFace Spaces with 4xL4 GPUs).
 
 This module loads the actual Qwen3-VL models for production use.
-
+Uses LAZY LOADING to fit within 88GB VRAM (4xL4 with ~22GB each).
+
+Memory Strategy:
+- Vision 30B (~60GB): Loaded ONLY during Stage 2 (Vision Analysis)
+- Embedding 8B (~16GB): Loaded ONLY during Stages 3+ (RAG)
+- Reranker 8B (~16GB): Loaded ONLY during Stages 3+ (RAG)
+- Peak usage: ~60GB (never all three simultaneously)
 
 Model Loading:
 - Vision: Qwen3VLMoeForConditionalGeneration (standard transformers)
@@ -9,6 +15,7 @@ Model Loading:
 - Reranker: Qwen3VLReranker (official scripts from QwenLM/Qwen3-VL-Embedding)
 """
 
+import gc
 import json
 import logging
 import re
@@ -24,27 +31,48 @@ logger = logging.getLogger(__name__)
 
 
 class RealModelStack:
-    """Real model stack for production on HuggingFace Spaces."""
+    """Real model stack for production on HuggingFace Spaces.
+
+    Uses LAZY LOADING to prevent OOM errors on 4xL4 (88GB total):
+    - Vision 30B (~60GB) and RAG models (~32GB) are never loaded simultaneously
+    - Pipeline calls load_vision() before Stage 2, unload_vision() after
+    - Pipeline calls load_rag() before Stage 3
+    """
 
     def __init__(self):
         self.models: dict[str, Any] = {}
         self.processors: dict[str, Any] = {}
-        self.loaded = False
-
-    def load_all(self) -> "RealModelStack":
-        """Load all models with device_map='auto' for multi-GPU distribution."""
-        from transformers import AutoProcessor
+        self._vision_loaded = False
+        self._rag_loaded = False
 
-
-
+    def _log_gpu_status(self):
+        """Log current GPU memory status."""
         if torch.cuda.is_available():
             gpu_count = torch.cuda.device_count()
-            logger.info(f"
+            logger.info(f"GPU memory status ({gpu_count} devices):")
             for i in range(gpu_count):
-
-
+                total = torch.cuda.get_device_properties(i).total_memory / (1024**3)
+                allocated = torch.cuda.memory_allocated(i) / (1024**3)
+                cached = torch.cuda.memory_reserved(i) / (1024**3)
+                free = total - allocated
+                logger.info(f"  GPU {i}: {allocated:.1f}GB allocated, {cached:.1f}GB cached, {free:.1f}GB free / {total:.1f}GB total")
+
+    def load_vision(self) -> "RealModelStack":
+        """Load only the vision model (~60GB in BF16).
+
+        Call this before Stage 2 (Vision Analysis).
+        Must call unload_vision() before load_rag() to free memory.
+        """
+        if self._vision_loaded:
+            logger.debug("Vision model already loaded, skipping")
+            return self
+
+        from transformers import AutoProcessor
+
+        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
+        logger.info(f"Loading vision model on {device_type}")
+        self._log_gpu_status()
 
-        # Vision model (~58GB in BF16)
         logger.info(f"Loading vision model: {settings.vision_model}")
         vision_start = time.time()
         try:
@@ -64,6 +92,8 @@ class RealModelStack:
         except Exception as e:
             logger.warning(f"Failed to load 30B vision model: {e}")
             logger.info(f"Falling back to {settings.vision_model_fallback}")
+            from transformers import Qwen3VLMoeForConditionalGeneration
+
             self.models["vision"] = Qwen3VLMoeForConditionalGeneration.from_pretrained(
                 settings.vision_model_fallback,
                 torch_dtype=torch.bfloat16,
@@ -76,6 +106,66 @@ class RealModelStack:
             )
             logger.info(f"Fallback vision model loaded in {time.time() - vision_start:.2f}s")
 
+        self._vision_loaded = True
+        self._log_gpu_status()
+        return self
+
+    def unload_vision(self):
+        """Unload vision model and free CUDA memory.
+
+        Uses accelerate's remove_hook_from_module per HuggingFace docs.
+        Call this after Stage 2 (Vision Analysis) to free memory for RAG.
+        """
+        if not self._vision_loaded or "vision" not in self.models:
+            logger.debug("Vision model not loaded, skipping unload")
+            return
+
+        logger.info("Unloading vision model to free memory for RAG...")
+        self._log_gpu_status()
+
+        try:
+            from accelerate.hooks import remove_hook_from_module
+
+            # CRITICAL: Remove hooks before deleting (required for device_map="auto")
+            model = self.models["vision"]
+            if hasattr(model, 'model'):
+                # Some wrappers have nested model
+                remove_hook_from_module(model.model, recurse=True)
+            remove_hook_from_module(model, recurse=True)
+            logger.debug("Accelerate hooks removed from vision model")
+        except ImportError:
+            logger.warning("accelerate.hooks not available, proceeding with basic cleanup")
+        except Exception as e:
+            logger.warning(f"Hook removal failed (continuing anyway): {e}")
+
+        # Delete model and processor
+        del self.models["vision"]
+        del self.processors["vision"]
+        self._vision_loaded = False
+
+        # Clear CUDA cache (may not free 100% but sufficient for sequential loading)
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        logger.info("Vision model unloaded, CUDA cache cleared")
+        self._log_gpu_status()
+
+    def load_rag(self) -> "RealModelStack":
+        """Load embedding and reranker models (~32GB total in BF16).
+
+        Call this before Stage 3 (RAG Retrieval).
+        Must call unload_vision() first to have enough memory.
+        """
+        if self._rag_loaded:
+            logger.debug("RAG models already loaded, skipping")
+            return self
+
+        if self._vision_loaded:
+            logger.warning("Vision model still loaded! Call unload_vision() first to avoid OOM.")
+
+        logger.info("Loading RAG models (embedding + reranker)...")
+        self._log_gpu_status()
+
         # Embedding model (~16GB in BF16) - Using official Qwen3VLEmbedder
         logger.info(f"Loading embedding model: {settings.embedding_model}")
         embed_start = time.time()
@@ -85,7 +175,6 @@ class RealModelStack:
             model_name_or_path=settings.embedding_model,
             torch_dtype=torch.bfloat16,
         )
-        # Processor is internal to Qwen3VLEmbedder, but store reference for compatibility
         self.processors["embedding"] = self.models["embedding"].processor
         logger.info(f"Embedding model loaded in {time.time() - embed_start:.2f}s")
 
@@ -98,31 +187,57 @@ class RealModelStack:
             model_name_or_path=settings.reranker_model,
             torch_dtype=torch.bfloat16,
         )
-        # Processor is internal to Qwen3VLReranker, but store reference for compatibility
         self.processors["reranker"] = self.models["reranker"].processor
         logger.info(f"Reranker model loaded in {time.time() - reranker_start:.2f}s")
 
-        self.loaded = True
-        logger.info("
+        self._rag_loaded = True
+        logger.info("RAG models loaded successfully")
+        self._log_gpu_status()
+        return self
+
+    def load_all(self) -> "RealModelStack":
+        """Load all models (DEPRECATED - use lazy loading instead).
+
+        This method is kept for backward compatibility but will cause OOM
+        on 4xL4 GPUs. Use load_vision() and load_rag() sequentially instead.
+        """
+        logger.warning("load_all() is deprecated - use load_vision() and load_rag() for lazy loading")
+        self.load_vision()
+        # Note: This WILL cause OOM on 4xL4 as vision (60GB) + RAG (32GB) > 88GB
+        self.load_rag()
         return self
 
     def is_loaded(self) -> bool:
-        """Check if models are loaded."""
-        return self.loaded
+        """Check if any models are loaded."""
+        return self._vision_loaded or self._rag_loaded
+
+    def is_vision_loaded(self) -> bool:
+        """Check if vision model is loaded."""
+        return self._vision_loaded
+
+    def is_rag_loaded(self) -> bool:
+        """Check if RAG models are loaded."""
+        return self._rag_loaded
 
     @property
     def vision(self) -> "RealVisionModel":
         """Return vision model wrapped for pipeline consumption."""
+        if not self._vision_loaded:
+            raise RuntimeError("Vision model not loaded. Call load_vision() first.")
         return RealVisionModel(self.models["vision"], self.processors["vision"])
 
     @property
     def embedding(self) -> "RealEmbeddingModel":
         """Return embedding model wrapped for pipeline consumption."""
+        if not self._rag_loaded:
+            raise RuntimeError("Embedding model not loaded. Call load_rag() first.")
        return RealEmbeddingModel(self.models["embedding"], self.processors["embedding"])
 
     @property
     def reranker(self) -> "RealRerankerModel":
         """Return reranker model wrapped for pipeline consumption."""
+        if not self._rag_loaded:
+            raise RuntimeError("Reranker model not loaded. Call load_rag() first.")
         return RealRerankerModel(self.models["reranker"], self.processors["reranker"])
 
 
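The unload path is the subtle part: models dispatched with device_map="auto" carry accelerate hooks that hold device references, so deleting the model alone may not release VRAM. A self-contained sketch of the same release pattern on a generic holder dict (not the repo's code):

import gc

import torch
from accelerate.hooks import remove_hook_from_module

def release_dispatched_model(holder: dict, key: str) -> None:
    """Free a model loaded with device_map="auto" and owned by holder[key]."""
    # Detach accelerate's device-placement hooks from every submodule;
    # recurse=True walks the whole module tree.
    remove_hook_from_module(holder[key], recurse=True)
    del holder[key]           # drop the last owning reference
    gc.collect()              # let Python release the tensors
    torch.cuda.empty_cache()  # return cached blocks to the CUDA driver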
pipeline/main.py
@@ -199,6 +199,11 @@ class FDAMPipeline:
         logger.info(f"Stage 2/6: Vision Analysis ({len(session.images)} images)")
         report_progress(2, "Analyzing images with AI...")
         model_stack = get_models()
+
+        # Lazy load vision model (for real models only - mock models are already loaded)
+        if hasattr(model_stack, 'load_vision') and not model_stack.is_vision_loaded():
+            logger.info("Lazy loading vision model...")
+            model_stack.load_vision()
         vision_results = {}
         annotated_images = []
         room_mapping = {}
@@ -259,10 +264,20 @@ class FDAMPipeline:
         logger.info(f"Stage 2 completed in {time.time() - stage_start:.2f}s: "
                     f"{len(vision_results)} images analyzed")
 
+        # Unload vision model to free memory for RAG (for real models only)
+        if hasattr(model_stack, 'unload_vision') and model_stack.is_vision_loaded():
+            logger.info("Unloading vision model to free memory for RAG...")
+            model_stack.unload_vision()
+
         # Stage 3: RAG Retrieval
         stage_start = time.time()
         logger.info("Stage 3/6: RAG Retrieval")
         report_progress(3, "Retrieving FDAM methodology context...")
+
+        # Lazy load RAG models (for real models only - mock models are already loaded)
+        if hasattr(model_stack, 'load_rag') and not model_stack.is_rag_loaded():
+            logger.info("Lazy loading RAG models (embedding + reranker)...")
+            model_stack.load_rag()
         # RAG is integrated into disposition engine, just verify connection
         try:
             test_results = self.retriever.retrieve("test connection", top_k=1)
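A possible follow-up (not part of this commit) is wrapping the load/unload pair in a context manager, so an exception inside Stage 2 cannot leave the ~60GB vision model resident:

from contextlib import contextmanager

@contextmanager
def vision_phase(model_stack):
    """Hypothetical helper: pairs load_vision() with a guaranteed unload_vision()."""
    if hasattr(model_stack, "load_vision") and not model_stack.is_vision_loaded():
        model_stack.load_vision()
    try:
        yield model_stack
    finally:
        if hasattr(model_stack, "unload_vision") and model_stack.is_vision_loaded():
            model_stack.unload_vision()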
rag/retriever.py
@@ -84,84 +84,50 @@ class MockReranker:
         return scores
 
 
-class
-    """
+class SharedReranker:
+    """Reranker that uses the shared model from RealModelStack.
 
-
+    This avoids loading a duplicate reranker model - instead uses the
+    model already loaded by the pipeline via model_stack.load_rag().
     """
 
-    def __init__(self):
-        self.model = None
-        self.tokenizer = None
-
-    def _load_model(self):
-        """Lazy load the reranker model."""
-        if self.model is not None:
-            return
-
-        import torch
-        from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
-        model_name = "Qwen/Qwen3-VL-Reranker-8B"
-        logger.info(f"Loading reranker model: {model_name}")
-
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            trust_remote_code=True,
-        )
-        self.model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-        self.model.eval()
-
     def rerank(
         self,
         query: str,
         documents: list[str],
     ) -> list[float]:
-        """Score documents using the reranker model.
+        """Score documents using the shared reranker model.
 
         Args:
             query: Query text
             documents: List of document texts
 
         Returns:
-            List of scores for each document
+            List of scores (0-1) for each document
         """
-
+        from models.loader import get_models
 
-
+        model_stack = get_models()
 
-
-
-
-
-
-                doc,
-                return_tensors="pt",
-                truncation=True,
-                max_length=512,
-                padding=True,
-            )
-            # Note: With device_map="auto", transformers handles device routing internally
-            # Do NOT call .to(device) - it breaks distributed models
+        # Check if RAG models are loaded
+        if not model_stack.is_rag_loaded():
+            logger.warning("RAG models not loaded yet - reranking may fail")
+            # Return neutral scores as fallback
+            return [0.5] * len(documents)
 
-
-
-            score = torch.sigmoid(outputs.logits).squeeze().item()
-            scores.append(score)
-
-        return scores
+        # Use the shared reranker model
+        return model_stack.reranker.rerank(query, documents)
 
 
 def get_reranker():
-    """Get appropriate reranker based on settings."""
+    """Get appropriate reranker based on settings.
+
+    For real models, uses SharedReranker which wraps the
+    model stack's reranker model (no duplicate loading).
+    """
     if settings.mock_models:
         return MockReranker()
-    return
+    return SharedReranker()
 
 
 class FDAMRetriever:
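Callers are unaffected by the swap; a sketch of the consumer-side flow (the query and documents are illustrative):

from rag.retriever import get_reranker

reranker = get_reranker()  # MockReranker or SharedReranker, per MOCK_MODELS
scores = reranker.rerank(
    query="drywall moisture damage",                   # illustrative
    documents=["passage one ...", "passage two ..."],  # illustrative
)
ranked = sorted(zip(scores, range(len(scores))), reverse=True)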
rag/vectorstore.py
@@ -58,100 +58,42 @@ class MockEmbeddingFunction:
         return embedding
 
 
-class
-    """
+class SharedEmbeddingFunction:
+    """Embedding function that uses the shared model from RealModelStack.
 
-
-
+    This avoids loading a duplicate embedding model - instead uses the
+    model already loaded by the pipeline via model_stack.load_rag().
 
-
+    For ChromaDB compatibility, this wraps the model stack's embedding model.
     """
 
     EMBEDDING_DIM = 4096  # Per Qwen3-VL-Embedding-8B hidden_size
 
-    def __init__(self):
-        self.model = None
-        self.tokenizer = None
-
-    def _load_model(self):
-        """Lazy load the embedding model."""
-        if self.model is not None:
-            return
-
-        import torch
-        from transformers import AutoModel, AutoTokenizer
-
-        model_name = "Qwen/Qwen3-VL-Embedding-8B"
-        logger.info(f"Loading embedding model: {model_name}")
-
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            trust_remote_code=True,
-        )
-        self.model = AutoModel.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-        self.model.eval()
-
-    @staticmethod
-    def _pooling_last(hidden_state, attention_mask):
-        """Extract the last valid token's hidden state.
-
-        Official pooling method from Qwen3-VL-Embedding.
-        Finds the last position where attention_mask == 1 and extracts that token.
-        """
-        import torch
-
-        flipped_tensor = attention_mask.flip(dims=[1])
-        last_one_positions = flipped_tensor.argmax(dim=1)
-        col = attention_mask.shape[1] - last_one_positions - 1
-        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
-        return hidden_state[row, col]
-
     def __call__(self, input: list[str]) -> list[list[float]]:
-        """Generate embeddings
-
-
-        import torch
-
-        embeddings = []
-        with torch.no_grad():
-            for text in input:
-                inputs = self.tokenizer(
-                    text,
-                    return_tensors="pt",
-                    truncation=True,
-                    max_length=512,
-                    padding=True,
-                )
-                # Note: With device_map="auto", transformers handles device routing internally
-                # Do NOT call .to(device) - it breaks distributed models
+        """Generate embeddings using the shared model from model stack."""
+        from models.loader import get_models
 
-
+        model_stack = get_models()
 
-
-
-
-
-
-                # Fallback: use last token if no attention mask
-                embedding = outputs.last_hidden_state[:, -1, :]
+        # Check if RAG models are loaded
+        if not model_stack.is_rag_loaded():
+            logger.warning("RAG models not loaded yet - embeddings may fail")
+            # Return zero vectors as fallback
+            return [[0.0] * self.EMBEDDING_DIM for _ in input]
 
-
-
-                embeddings.append(embedding.squeeze().cpu().float().tolist())
-
-        return embeddings
+        # Use the shared embedding model
+        return model_stack.embedding.embed_batch(input)
 
 
 def get_embedding_function():
-    """Get appropriate embedding function based on settings."""
+    """Get appropriate embedding function based on settings.
+
+    For real models, uses SharedEmbeddingFunction which wraps the
+    model stack's embedding model (no duplicate loading).
+    """
     if settings.mock_models:
         return MockEmbeddingFunction()
-    return
+    return SharedEmbeddingFunction()
 
 
 class ChromaVectorStore: