SmokeScan / models /real.py
KinetoLabs's picture
Add enforce_eager=True to fix KV cache memory issue
1b7fbd7
"""Real model loading for production (HuggingFace Spaces).
This module loads the production models:
- Vision: Qwen/Qwen3-VL-4B-Thinking (~10GB via vLLM, single GPU)
- Embedding: Qwen/Qwen3-VL-Embedding-2B (~4GB)
- Reranker: Qwen/Qwen3-VL-Reranker-2B (~4GB)
- Total: ~18GB on single L4 GPU (22GB)
Model Loading:
- Vision: vLLM with single GPU (no tensor parallelism needed)
- Embedding: Qwen3VLEmbedder (official scripts from QwenLM/Qwen3-VL-Embedding)
- Reranker: Qwen3VLReranker (official scripts from QwenLM/Qwen3-VL-Embedding)
"""
import os
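# NOTE: no vLLM environment variables are set in this module. The vision model
# runs on a single GPU (TP=1), so NCCL workarounds are not needed. If any
# become necessary, set them via os.environ at this point: vLLM is imported
# lazily inside RealModelStack.load_all(), so this code runs before that import.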
import json
import logging
import re
import time
import torch
from typing import Any
from PIL import Image
from config.inference import vision_config
from config.settings import settings
logger = logging.getLogger(__name__)
class RealModelStack:
    """Production model stack for HuggingFace Spaces.

    Holds three models on a single GPU (~18GB total per this module's sizing
    notes):
    - Vision 4B via vLLM: ~10GB
    - Embedding 2B: ~4GB
    - Reranker 2B: ~4GB

    ``load_all()`` populates ``self.models`` / ``self.processors``; the
    ``vision``, ``embedding`` and ``reranker`` properties hand out thin
    wrappers around the loaded objects.
    """

    _BYTES_PER_GB = 1024**3
    _NOT_LOADED_MSG = "Models not loaded. Call load_all() first."

    def __init__(self):
        self.models: dict[str, Any] = {}
        self.processors: dict[str, Any] = {}
        self._loaded = False

    def _log_gpu_status(self):
        """Log per-device GPU memory usage (no-op when CUDA is unavailable).

        ``allocated`` / ``cached`` come from this process's PyTorch caching
        allocator. ``free`` / ``total`` come from ``torch.cuda.mem_get_info``,
        i.e. the driver's device-wide view, so memory held by other processes
        or outside the PyTorch allocator is accounted for as well.
        """
        if not torch.cuda.is_available():
            return

        gpu_count = torch.cuda.device_count()
        logger.info(f"GPU memory status ({gpu_count} devices):")
        for i in range(gpu_count):
            free_bytes, total_bytes = torch.cuda.mem_get_info(i)
            total = total_bytes / self._BYTES_PER_GB
            free = free_bytes / self._BYTES_PER_GB
            allocated = torch.cuda.memory_allocated(i) / self._BYTES_PER_GB
            cached = torch.cuda.memory_reserved(i) / self._BYTES_PER_GB
            logger.info(f" GPU {i}: {allocated:.1f}GB allocated, {cached:.1f}GB cached, {free:.1f}GB free / {total:.1f}GB total")

    def load_all(self) -> "RealModelStack":
        """Load the vision model (via vLLM) and the RAG models.

        Idempotent: returns immediately once everything is loaded. If a
        previous call failed part-way, models that were already stored are
        reused rather than instantiated a second time on the same GPU.

        Returns:
            ``self``, to allow ``RealModelStack().load_all()`` chaining.

        Raises:
            Whatever the underlying loaders raise; ``is_loaded()`` stays
            False in that case.
        """
        if self._loaded:
            logger.debug("Models already loaded, skipping")
            return self

        logger.info("Loading production models...")
        self._log_gpu_status()
        total_start = time.perf_counter()

        # Order matters for memory: vLLM claims its share of the GPU first,
        # the two smaller models use what is left.
        self._load_vision()
        self._load_embedding()
        self._load_reranker()

        self._loaded = True
        total_time = time.perf_counter() - total_start
        logger.info(f"All models loaded in {total_time:.2f}s")
        self._log_gpu_status()
        return self

    def _load_vision(self) -> None:
        """Create the vLLM engine, its chat-template processor and sampling params."""
        logger.info(f"Loading vision model: {settings.vision_model}")
        vision_start = time.perf_counter()

        # Imported lazily so that importing this module does not pull in vLLM.
        from vllm import LLM, SamplingParams
        from transformers import AutoProcessor

        if "vision" not in self.models:
            self.models["vision"] = LLM(
                model=settings.vision_model,
                tensor_parallel_size=settings.vllm_tensor_parallel_size,  # 1 for single GPU
                trust_remote_code=True,
                gpu_memory_utilization=0.55,  # Leave ~10GB for embedding + reranker
                max_model_len=8192,  # Reduced to save KV cache memory
                enforce_eager=True,  # Eager mode: no CUDA graph capture, less memory overhead
            )

        # Processor is only used for chat template formatting.
        if "vision" not in self.processors:
            self.processors["vision"] = AutoProcessor.from_pretrained(
                settings.vision_model,
                trust_remote_code=True,
            )

        # Sampling params are stored next to the model for use at inference time.
        self.models["vision_sampling_params"] = SamplingParams(
            max_tokens=vision_config.max_tokens,
            temperature=vision_config.temperature,
            top_p=vision_config.top_p,
            top_k=vision_config.top_k,
            repetition_penalty=vision_config.repetition_penalty,
        )
        logger.info(f"Vision model loaded in {time.perf_counter() - vision_start:.2f}s")

    def _load_embedding(self) -> None:
        """Load the embedding model (~4GB in BF16) with the official Qwen3VLEmbedder."""
        logger.info(f"Loading embedding model: {settings.embedding_model}")
        embed_start = time.perf_counter()

        from scripts.qwen3_vl import Qwen3VLEmbedder

        if "embedding" not in self.models:
            self.models["embedding"] = Qwen3VLEmbedder(
                model_name_or_path=settings.embedding_model,
                torch_dtype=torch.bfloat16,
            )
        self.processors["embedding"] = self.models["embedding"].processor
        logger.info(f"Embedding model loaded in {time.perf_counter() - embed_start:.2f}s")

    def _load_reranker(self) -> None:
        """Load the reranker model (~4GB in BF16) with the official Qwen3VLReranker."""
        logger.info(f"Loading reranker model: {settings.reranker_model}")
        reranker_start = time.perf_counter()

        from scripts.qwen3_vl import Qwen3VLReranker

        if "reranker" not in self.models:
            self.models["reranker"] = Qwen3VLReranker(
                model_name_or_path=settings.reranker_model,
                torch_dtype=torch.bfloat16,
            )
        self.processors["reranker"] = self.models["reranker"].processor
        logger.info(f"Reranker model loaded in {time.perf_counter() - reranker_start:.2f}s")

    def is_loaded(self) -> bool:
        """Return True once ``load_all()`` has completed successfully."""
        return self._loaded

    def _require_loaded(self) -> None:
        """Raise RuntimeError unless ``load_all()`` has completed."""
        if not self._loaded:
            raise RuntimeError(self._NOT_LOADED_MSG)

    @property
    def vision(self) -> "VisionModel":
        """Return the vision model wrapped for pipeline consumption.

        Raises:
            RuntimeError: if ``load_all()`` has not completed.
        """
        self._require_loaded()
        return VisionModel(
            model=self.models["vision"],
            processor=self.processors["vision"],
            sampling_params=self.models["vision_sampling_params"],
        )

    @property
    def embedding(self) -> "RealEmbeddingModel":
        """Return the embedding model wrapped for pipeline consumption.

        Raises:
            RuntimeError: if ``load_all()`` has not completed.
        """
        self._require_loaded()
        return RealEmbeddingModel(self.models["embedding"], self.processors["embedding"])

    @property
    def reranker(self) -> "RealRerankerModel":
        """Return the reranker model wrapped for pipeline consumption.

        Raises:
            RuntimeError: if ``load_all()`` has not completed.
        """
        self._require_loaded()
        return RealRerankerModel(self.models["reranker"], self.processors["reranker"])
class VisionModel:
    """Vision model wrapper for fire damage analysis.

    Uses Qwen/Qwen3-VL-4B-Thinking via vLLM for inference. The model reasons
    first and then emits structured JSON.

    Pipeline: Image -> Thinking Model (reasoning + JSON) -> Output
    """

    # Marker that closes the model's reasoning section; the JSON answer is
    # expected after it.
    _THINK_END = "</think>"

    # System prompt for FDAM fire damage assessment
    VISION_SYSTEM_PROMPT = """You are an expert industrial hygienist analyzing fire damage images for the FDAM (Fire Damage Assessment Methodology) framework.
## Your Task
Analyze the provided image and return a structured JSON response with fire damage assessment.
## Zone Classification Criteria
- **Burn Zone**: Direct fire involvement. Look for structural char, complete combustion, exposed/damaged structural elements.
- **Near-Field**: Adjacent to burn zone with heavy smoke/heat exposure. Look for heavy soot deposits, heat damage (warping, discoloration), strong visible contamination.
- **Far-Field**: Smoke migration without direct heat exposure. Look for light to moderate deposits, discoloration, no structural damage.
## Condition Assessment Criteria
- **Background**: No visible contamination; surfaces appear normal/clean.
- **Light**: Faint discoloration; minimal visible deposits; would show faint marks on white wipe test.
- **Moderate**: Visible film or deposits; clear contamination; surface color noticeably altered.
- **Heavy**: Thick deposits; surface texture obscured; heavy coating visible.
- **Structural Damage**: Physical damage requiring repair before cleaning (charring, warping, holes, collapse).
## Material Categories
- **Non-porous**: steel, concrete, glass, metal, CMU (concrete masonry unit)
- **Semi-porous**: painted drywall, sealed wood
- **Porous**: unpainted drywall, carpet, insulation, acoustic tile, upholstery
- **HVAC**: rigid ductwork, flexible ductwork
## Combustion Particle Visual Indicators
- **Soot**: Black/dark gray coating with oily/sticky appearance; fine uniform texture
- **Char**: Black angular fragments; visible wood grain or fibrous structure
- **Ash**: Gray/white powdery residue; crystalline appearance"""

    # JSON output format prompt
    JSON_FORMAT_PROMPT = """Analyze this fire damage image and return a JSON response with this exact structure:
{
"zone": {
"classification": "burn" | "near-field" | "far-field",
"confidence": 0.0-1.0,
"reasoning": "explanation"
},
"condition": {
"level": "background" | "light" | "moderate" | "heavy" | "structural-damage",
"confidence": 0.0-1.0,
"reasoning": "explanation"
},
"materials": [
{
"type": "material type",
"category": "non-porous" | "semi-porous" | "porous" | "hvac",
"confidence": 0.0-1.0,
"location_description": "where in image",
"bounding_box": {"x": 0.0-1.0, "y": 0.0-1.0, "width": 0.0-1.0, "height": 0.0-1.0}
}
],
"combustion_indicators": {
"soot_visible": true/false,
"soot_pattern": "description or null",
"char_visible": true/false,
"char_description": "description or null",
"ash_visible": true/false,
"ash_description": "description or null"
},
"structural_concerns": ["list of structural issues if any"],
"access_issues": ["list of access problems if any"],
"recommended_sampling_locations": [
{
"description": "where to sample",
"sample_type": "tape_lift" | "surface_wipe" | "air_sample",
"priority": "high" | "medium" | "low"
}
],
"flags_for_review": ["any items requiring human review"]
}
IMPORTANT: Return ONLY valid JSON, no additional text."""

    def __init__(self, model, processor, sampling_params):
        """Store the vLLM engine, chat-template processor and sampling params."""
        self.model = model
        self.processor = processor
        self.sampling_params = sampling_params

    def analyze_image(self, image: "Image.Image", context: str = "") -> dict[str, Any]:
        """Analyze an image with the vision model via vLLM.

        Args:
            image: PIL Image to analyze
            context: Optional context string (room info, etc.)

        Returns:
            Structured dict with zone, condition, materials, etc. Never
            raises: on any failure the fallback response is returned, marked
            with ``"_fallback_used": True``.
        """
        start_time = time.perf_counter()
        logger.debug(f"Starting FP8 vision analysis (context: {len(context)} chars)")
        try:
            # Build messages in Qwen3-VL format
            messages = self._build_messages(image, context)

            # Apply chat template to format prompt correctly
            prompt = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

            # Generate response using vLLM multimodal API
            # Per vLLM docs: pass PIL image directly in multi_modal_data dict
            outputs = self.model.generate(
                prompts=[{
                    "prompt": prompt,
                    "multi_modal_data": {"image": image},  # Single PIL image
                }],
                sampling_params=self.sampling_params,
            )
            completion = outputs[0].outputs[0]
            if getattr(completion, "finish_reason", None) == "length":
                # Reasoning can consume the whole token budget, leaving the
                # JSON answer cut off or missing.
                logger.warning("Vision response hit the token limit; JSON may be truncated")

            result = self._parse_json_response(completion.text)

            total_time = time.perf_counter() - start_time
            logger.info(f"Vision analysis complete in {total_time:.2f}s: {self._summarize(result)}")
            return result
        except Exception as e:
            # Deliberate best-effort boundary: the pipeline always gets a dict.
            logger.exception(f"Vision analysis failed: {e}")
            return self._get_fallback_response(str(e))

    @staticmethod
    def _summarize(result: dict[str, Any]) -> str:
        """Build the one-line log summary without raising on odd model output.

        The model may emit ``null`` or wrongly typed values for any field; a
        logging detail must not turn a parsed result into a fallback.
        """
        def section(name: str) -> dict:
            value = result.get(name)
            return value if isinstance(value, dict) else {}

        def confidence(value: Any) -> float:
            try:
                return float(value)
            except (TypeError, ValueError):
                return 0.0

        zone = section("zone")
        condition = section("condition")
        materials = result.get("materials")
        num_materials = len(materials) if isinstance(materials, (list, tuple)) else 0
        return (
            f"zone={zone.get('classification', 'unknown')} ({confidence(zone.get('confidence', 0)):.2f}), "
            f"condition={condition.get('level', 'unknown')} ({confidence(condition.get('confidence', 0)):.2f}), "
            f"materials={num_materials}"
        )

    def _build_messages(self, image: "Image.Image", context: str) -> list[dict]:
        """Build messages in Qwen3-VL format for the chat template.

        Qwen3-VL expects:
        - System message with role="system"
        - User message with mixed content [{"type": "image", ...}, {"type": "text", ...}]

        A non-empty ``context`` is prepended to the JSON-format instructions.
        """
        user_text = self.JSON_FORMAT_PROMPT
        if context:
            user_text = f"Context: {context}\n\n{user_text}"
        return [
            {"role": "system", "content": self.VISION_SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": user_text},
                ],
            },
        ]

    def _parse_json_response(self, response: str) -> dict[str, Any]:
        """Extract the JSON object from the model response.

        Everything up to the reasoning terminator (``</think>``) is ignored so
        braces inside the reasoning cannot corrupt the match. The first JSON
        object in the remaining text is decoded; surrounding text (code
        fences, trailing remarks) is tolerated.

        Returns:
            The decoded object, or the fallback response when no JSON object
            is present or it cannot be decoded.
        """
        _, marker, answer = response.partition(self._THINK_END)
        candidate = answer if marker else response

        start = candidate.find("{")
        if start == -1:
            logger.warning("No JSON found in response")
            return self._get_fallback_response("No JSON in response")

        try:
            # raw_decode stops at the end of the first complete value, so text
            # after the object does not matter. A value starting with "{" is
            # always decoded as a dict.
            parsed, _ = json.JSONDecoder().raw_decode(candidate, start)
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse JSON: {e}")
            return self._get_fallback_response(f"JSON parse error: {e}")
        return parsed

    def _get_fallback_response(self, reason: str) -> dict[str, Any]:
        """Return the placeholder response used when analysis fails.

        The result has the same shape as a real analysis, carries low (0.3)
        confidences, records ``reason`` in the reasoning fields and in
        ``flags_for_review``, and sets ``"_fallback_used": True`` so callers
        can tell it apart from a genuine assessment.
        """
        return {
            "zone": {
                "classification": "far-field",
                "confidence": 0.3,
                "reasoning": f"Fallback due to: {reason}",
            },
            "condition": {
                "level": "light",
                "confidence": 0.3,
                "reasoning": f"Fallback due to: {reason}",
            },
            "materials": [
                {
                    "type": "general-surface",
                    "category": "semi-porous",
                    "confidence": 0.3,
                    "location_description": "Unable to determine",
                    "bounding_box": {"x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0},
                }
            ],
            "combustion_indicators": {
                "soot_visible": False,
                "soot_pattern": None,
                "char_visible": False,
                "char_description": None,
                "ash_visible": False,
                "ash_description": None,
            },
            "structural_concerns": [],
            "access_issues": [],
            "recommended_sampling_locations": [],
            "flags_for_review": [f"Analysis failed: {reason}"],
            "_fallback_used": True,
        }
class RealEmbeddingModel:
    """Wrapper for real embedding model inference.

    Uses the official Qwen3VLEmbedder from QwenLM/Qwen3-VL-Embedding.
    The model handles last-token pooling and L2 normalization internally.

    Model: Qwen/Qwen3-VL-Embedding-2B (2048-dim output)
    """

    # Zero-vector length used when the model config cannot be inspected
    # (2048-dim per Qwen3-VL-Embedding-2B).
    _DEFAULT_DIM = 2048

    def __init__(self, model, processor):
        """Initialize with Qwen3VLEmbedder instance.

        Args:
            model: Qwen3VLEmbedder instance (official loader)
            processor: Processor (stored for compatibility, but model has its own)
        """
        self.model = model
        self.processor = processor

    def _fallback_dim(self) -> int:
        """Return the length to use for zero-vector fallbacks; never raises.

        Reads ``hidden_size`` from the wrapped model's config when available.
        NOTE(review): the nested ``text_config`` lookup assumes a
        multimodal-style config layout - confirm against the loaded model.
        """
        config = getattr(getattr(self.model, "model", None), "config", None)
        for cfg in (config, getattr(config, "text_config", None)):
            size = getattr(cfg, "hidden_size", None)
            if isinstance(size, int) and size > 0:
                return size
        return self._DEFAULT_DIM

    def embed(self, text: str) -> list[float]:
        """Generate embedding for text using official Qwen3VLEmbedder.

        The official model.process() handles:
        - Tokenization and preprocessing
        - Last-token pooling
        - L2 normalization

        Args:
            text: Input text to embed

        Returns:
            List of floats representing the embedding (2048-dim for 2B model).
            On failure the error is logged and a zero vector is returned.
        """
        try:
            # Official process() API expects a list of dicts and returns a
            # tensor of shape (1, hidden_dim) for a single input.
            embeddings = self.model.process([{"text": text}], normalize=True)
            return embeddings[0].cpu().tolist()
        except Exception as e:
            logger.exception(f"Embedding generation failed: {e}")
            return [0.0] * self._fallback_dim()

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a batch of texts.

        Uses official batch processing for efficiency.

        Args:
            texts: Input texts to embed; may be empty.

        Returns:
            One embedding per input text, in order. An empty input yields an
            empty list without calling the model. On failure the error is
            logged and one zero vector per input is returned.
        """
        if not texts:
            return []
        try:
            inputs = [{"text": text} for text in texts]
            embeddings = self.model.process(inputs, normalize=True)
            return [emb.cpu().tolist() for emb in embeddings]
        except Exception as e:
            logger.exception(f"Batch embedding generation failed: {e}")
            hidden_size = self._fallback_dim()
            return [[0.0] * hidden_size for _ in texts]
class RealRerankerModel:
    """Wrapper for real reranker model inference.

    Uses the official Qwen3VLReranker from QwenLM/Qwen3-VL-Embedding.
    The model handles yes/no scoring internally via:
    - Extracts "yes" and "no" token weights from the LM head
    - Creates a binary linear layer: weight = yes_weight - no_weight
    - Scores = sigmoid(linear(last_token_hidden_state))

    Model: Qwen/Qwen3-VL-Reranker-2B
    """

    def __init__(self, model, processor):
        """Initialize with Qwen3VLReranker instance.

        Args:
            model: Qwen3VLReranker instance (official loader)
            processor: Processor (stored for compatibility, but model has its own)
        """
        self.model = model
        self.processor = processor

    def rerank(self, query: str, documents: list[str]) -> list[float]:
        """Rerank documents by relevance to query using official Qwen3VLReranker.

        The official model.process() handles:
        - Proper message formatting
        - Tokenization
        - Yes/no scoring with LM head weights
        - Sigmoid normalization

        Args:
            query: The search query
            documents: List of documents to rerank

        Returns:
            List of relevance scores (0-1) for each document, as plain Python
            floats. Higher scores indicate more relevant documents. An empty
            ``documents`` yields ``[]``; on failure the error is logged and a
            score of 0.0 is returned for every document.
        """
        if not documents:
            return []
        try:
            # Official process() API expects a dict with query and documents
            inputs = {
                "instruction": "Retrieve relevant documents for the query.",
                "query": {"text": query},
                "documents": [{"text": doc} for doc in documents],
            }
            raw_scores = self.model.process(inputs)
            # Normalise tensor/array-like results to the annotated list[float].
            if hasattr(raw_scores, "tolist"):
                raw_scores = raw_scores.tolist()
            scores = [float(score) for score in raw_scores]
            if len(scores) != len(documents):
                logger.warning(
                    f"Reranker returned {len(scores)} scores for {len(documents)} documents"
                )
            return scores
        except Exception as e:
            logger.exception(f"Reranking failed: {e}")
            return [0.0] * len(documents)

    def rerank_with_indices(
        self, query: str, documents: list[str], top_k: "int | None" = None
    ) -> list[tuple[int, float]]:
        """Rerank and return sorted (index, score) tuples.

        Args:
            query: The search query
            documents: List of documents to rerank
            top_k: Optional limit on number of results; ``None`` returns all,
                zero or a negative value returns none.

        Returns:
            List of (original_index, score) tuples, sorted by score descending.
            Documents with equal scores keep their original relative order.
        """
        scores = self.rerank(query, documents)
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        if top_k is not None:
            # Clamp so a negative limit cannot act as a "drop the last N" slice.
            ranked = ranked[:max(top_k, 0)]
        return ranked