Spaces:
Paused
Paused
| """Real model loading for production (HuggingFace Spaces with 4xL4 GPUs). | |
| This module loads the actual Qwen3-VL models for production use. | |
| Requires ~90GB VRAM (4xL4 with 96GB total). | |
| Model Loading: | |
| - Vision: Qwen3VLMoeForConditionalGeneration (standard transformers) | |
| - Embedding: Qwen3VLEmbedder (official scripts from QwenLM/Qwen3-VL-Embedding) | |
| - Reranker: Qwen3VLReranker (official scripts from QwenLM/Qwen3-VL-Embedding) | |
| """ | |
import json
import logging
import re
import time
from typing import Any

import torch
from PIL import Image

from config.inference import vision_config
from config.settings import settings

logger = logging.getLogger(__name__)
class RealModelStack:
    """Real model stack for production on HuggingFace Spaces.

    Owns the three production models (vision, embedding, reranker) plus their
    processors, and hands out thin per-role wrappers for pipeline consumption.
    """

    def __init__(self):
        # Loaded model objects keyed by role: "vision", "embedding", "reranker".
        self.models: dict[str, Any] = {}
        # Processors/tokenizers under the same keys as `models`.
        self.processors: dict[str, Any] = {}
        # Flipped to True only after load_all() finishes without raising.
        self.loaded = False

    def load_all(self) -> "RealModelStack":
        """Load all models with device_map='auto' for multi-GPU distribution.

        Returns:
            self, so callers can chain: ``stack = RealModelStack().load_all()``.
        """
        self._log_gpu_inventory()
        self._load_vision()
        self._load_embedding()
        self._load_reranker()
        self.loaded = True
        logger.info("All models loaded successfully")
        return self

    def _log_gpu_inventory(self) -> None:
        """Log the compute device and per-GPU memory for startup diagnostics."""
        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Loading models on {device_type}")
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            logger.info(f"CUDA devices available: {gpu_count}")
            for i in range(gpu_count):
                mem_gb = torch.cuda.get_device_properties(i).total_memory / (1024**3)
                logger.info(f"  GPU {i}: {torch.cuda.get_device_name(i)} ({mem_gb:.1f} GB)")

    def _load_vision(self) -> None:
        """Load the vision model (~58GB in BF16), falling back to the smaller model.

        Bug fix: the model class is imported *before* the try block. Previously
        it was imported inside the try, so if the import itself failed (e.g. a
        transformers version without Qwen3VLMoeForConditionalGeneration), the
        fallback path raised NameError instead of attempting the fallback load.
        """
        from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration

        logger.info(f"Loading vision model: {settings.vision_model}")
        vision_start = time.time()
        try:
            self.models["vision"] = Qwen3VLMoeForConditionalGeneration.from_pretrained(
                settings.vision_model,
                torch_dtype=torch.bfloat16,
                device_map="auto",  # transformers/accelerate shards across GPUs
                trust_remote_code=True,
            )
            self.processors["vision"] = AutoProcessor.from_pretrained(
                settings.vision_model,
                trust_remote_code=True,
            )
            logger.info(f"Vision model loaded in {time.time() - vision_start:.2f}s")
        except Exception as e:
            logger.warning(f"Failed to load 30B vision model: {e}")
            logger.info(f"Falling back to {settings.vision_model_fallback}")
            self.models["vision"] = Qwen3VLMoeForConditionalGeneration.from_pretrained(
                settings.vision_model_fallback,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )
            self.processors["vision"] = AutoProcessor.from_pretrained(
                settings.vision_model_fallback,
                trust_remote_code=True,
            )
            logger.info(f"Fallback vision model loaded in {time.time() - vision_start:.2f}s")

    def _load_embedding(self) -> None:
        """Load the embedding model (~16GB in BF16) via the official Qwen3VLEmbedder."""
        logger.info(f"Loading embedding model: {settings.embedding_model}")
        embed_start = time.time()
        from scripts.qwen3_vl import Qwen3VLEmbedder
        self.models["embedding"] = Qwen3VLEmbedder(
            model_name_or_path=settings.embedding_model,
            torch_dtype=torch.bfloat16,
        )
        # Processor is internal to Qwen3VLEmbedder; keep a reference for compatibility.
        self.processors["embedding"] = self.models["embedding"].processor
        logger.info(f"Embedding model loaded in {time.time() - embed_start:.2f}s")

    def _load_reranker(self) -> None:
        """Load the reranker model (~16GB in BF16) via the official Qwen3VLReranker."""
        logger.info(f"Loading reranker model: {settings.reranker_model}")
        reranker_start = time.time()
        from scripts.qwen3_vl import Qwen3VLReranker
        self.models["reranker"] = Qwen3VLReranker(
            model_name_or_path=settings.reranker_model,
            torch_dtype=torch.bfloat16,
        )
        # Processor is internal to Qwen3VLReranker; keep a reference for compatibility.
        self.processors["reranker"] = self.models["reranker"].processor
        logger.info(f"Reranker model loaded in {time.time() - reranker_start:.2f}s")

    def is_loaded(self) -> bool:
        """Check if models are loaded."""
        return self.loaded

    def vision(self) -> "RealVisionModel":
        """Return vision model wrapped for pipeline consumption."""
        return RealVisionModel(self.models["vision"], self.processors["vision"])

    def embedding(self) -> "RealEmbeddingModel":
        """Return embedding model wrapped for pipeline consumption."""
        return RealEmbeddingModel(self.models["embedding"], self.processors["embedding"])

    def reranker(self) -> "RealRerankerModel":
        """Return reranker model wrapped for pipeline consumption."""
        return RealRerankerModel(self.models["reranker"], self.processors["reranker"])
class RealVisionModel:
    """Wrapper for real vision model inference.

    Formats an FDAM analysis request (system prompt + image + JSON-schema
    prompt), runs generation on the wrapped Qwen-VL model, and parses the
    model's JSON reply into a dict. On any failure it returns a conservative
    fallback response flagged with "_fallback_used".
    """

    # System prompt for FDAM fire damage assessment (per Technical Spec Section 7)
    VISION_SYSTEM_PROMPT = """You are an expert industrial hygienist analyzing fire damage images for the FDAM (Fire Damage Assessment Methodology) framework.
## Your Task
Analyze the provided image and extract structured information about fire damage, materials, and conditions.
## Zone Classification Criteria
- **Burn Zone**: Direct fire involvement. Look for structural char, complete combustion, exposed/damaged structural elements.
- **Near-Field**: Adjacent to burn zone with heavy smoke/heat exposure. Look for heavy soot deposits, heat damage (warping, discoloration), strong visible contamination.
- **Far-Field**: Smoke migration without direct heat exposure. Look for light to moderate deposits, discoloration, no structural damage.
## Condition Assessment Criteria
- **Background**: No visible contamination; surfaces appear normal/clean.
- **Light**: Faint discoloration; minimal visible deposits; would show faint marks on white wipe test.
- **Moderate**: Visible film or deposits; clear contamination; surface color noticeably altered.
- **Heavy**: Thick deposits; surface texture obscured; heavy coating visible.
- **Structural Damage**: Physical damage requiring repair before cleaning (charring, warping, holes, collapse).
## Material Identification
Identify visible materials and categorize as:
- **Non-porous**: steel, concrete, glass, metal, CMU (concrete masonry unit)
- **Semi-porous**: painted drywall, sealed wood
- **Porous**: unpainted drywall, carpet, insulation, acoustic tile, upholstery
- **HVAC**: rigid ductwork, flexible ductwork
## Combustion Particle Visual Indicators
- **Soot**: Black/dark gray coating with oily/sticky appearance; fine uniform texture; often creates "shadow" patterns
- **Char**: Black angular fragments; visible wood grain or fibrous structure; larger particles
- **Ash**: Gray/white powdery residue; crystalline appearance; often found with char
## Important Notes
- This is VISUAL assessment only - definitive particle identification requires laboratory analysis
- When uncertain between two classifications, note both with relative confidence
- Flag any areas that require professional on-site verification
- Note any potential access issues visible in the image"""

    # Analysis prompt template with JSON schema
    ANALYSIS_PROMPT = """Analyze this fire damage image and return a JSON response with the following structure:
{
  "zone": {
    "classification": "burn" | "near-field" | "far-field",
    "confidence": 0.0-1.0,
    "reasoning": "explanation"
  },
  "condition": {
    "level": "background" | "light" | "moderate" | "heavy" | "structural-damage",
    "confidence": 0.0-1.0,
    "reasoning": "explanation"
  },
  "materials": [
    {
      "type": "material type (e.g., drywall, concrete, steel, wood)",
      "category": "non-porous" | "semi-porous" | "porous" | "hvac",
      "confidence": 0.0-1.0,
      "location_description": "where in image",
      "bounding_box": {"x": 0.0-1.0, "y": 0.0-1.0, "width": 0.0-1.0, "height": 0.0-1.0}
    }
  ],
  "combustion_indicators": {
    "soot_visible": true/false,
    "soot_pattern": "description or null",
    "char_visible": true/false,
    "char_description": "description or null",
    "ash_visible": true/false,
    "ash_description": "description or null"
  },
  "structural_concerns": ["list of structural issues if any"],
  "access_issues": ["list of access problems if any"],
  "recommended_sampling_locations": [
    {
      "description": "where to sample",
      "sample_type": "tape_lift" | "surface_wipe" | "air_sample",
      "priority": "high" | "medium" | "low"
    }
  ],
  "flags_for_review": ["any items requiring human review"]
}
IMPORTANT: Return ONLY valid JSON, no additional text."""

    def __init__(self, model, processor):
        # model: a generate()-capable Qwen-VL model; processor: its matching
        # AutoProcessor (chat template + tokenization + image preprocessing).
        self.model = model
        self.processor = processor

    def analyze_image(self, image: "Image.Image", context: str = "") -> dict[str, Any]:
        """Analyze an image and return structured results.

        Args:
            image: PIL image to analyze.
            context: Optional free-text context prepended to the analysis prompt.

        Returns:
            Parsed dict matching ANALYSIS_PROMPT's schema, or the fallback
            response (containing "_fallback_used": True) on any failure.
        """
        start_time = time.time()
        logger.debug(f"Starting vision analysis (context: {len(context)} chars)")
        try:
            from qwen_vl_utils import process_vision_info
        except ImportError:
            logger.warning("qwen_vl_utils not available, using basic processing")
            process_vision_info = None
        # Build the analysis prompt with context
        prompt = self.ANALYSIS_PROMPT
        if context:
            prompt = f"Context: {context}\n\n{prompt}"
        # Prepare messages in Qwen-VL format with system prompt
        messages = [
            {
                "role": "system",
                "content": self.VISION_SYSTEM_PROMPT,
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        try:
            # Apply chat template
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # Process vision info if available
            if process_vision_info:
                image_inputs, video_inputs = process_vision_info(messages)
                inputs = self.processor(
                    text=[text],
                    images=image_inputs,
                    videos=video_inputs,
                    return_tensors="pt",
                    padding=True,
                )
            else:
                # Fallback: basic image processing
                inputs = self.processor(
                    text=[text],
                    images=[image],
                    return_tensors="pt",
                    padding=True,
                )
            # Note: With device_map="auto", transformers handles device routing internally
            # Do NOT call .to(device) - it breaks distributed models
            # Log inference config being used
            logger.debug(f"Vision inference config: max_new_tokens={vision_config.max_new_tokens}, "
                         f"do_sample={vision_config.do_sample}, temp={vision_config.temperature}")
            # Generate response using config values
            inference_start = time.time()
            with torch.no_grad():
                if vision_config.do_sample:
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=vision_config.max_new_tokens,
                        do_sample=True,
                        temperature=vision_config.temperature,
                        top_p=vision_config.top_p,
                        repetition_penalty=vision_config.repetition_penalty,
                    )
                else:
                    # Deterministic mode (no sampling); temperature/top_p are
                    # explicitly None so generate() doesn't warn about unused
                    # sampling parameters.
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=vision_config.max_new_tokens,
                        do_sample=False,
                        temperature=None,
                        top_p=None,
                        repetition_penalty=vision_config.repetition_penalty,
                    )
            inference_time = time.time() - inference_start
            logger.debug(f"Vision inference completed in {inference_time:.2f}s")
            # BUG FIX: decode only the newly generated tokens. Decoding
            # outputs[0] in full included the prompt, and ANALYSIS_PROMPT's
            # JSON schema braces were then captured by the greedy JSON
            # extraction regex instead of the model's actual answer.
            prompt_len = inputs["input_ids"].shape[1]
            response_text = self.processor.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )
            logger.debug(f"Response length: {len(response_text)} chars")
            # Parse JSON from response
            result = self._parse_vision_response(response_text)
            # Log result summary
            total_time = time.time() - start_time
            zone = result.get("zone", {}).get("classification", "unknown")
            zone_conf = result.get("zone", {}).get("confidence", 0)
            condition = result.get("condition", {}).get("level", "unknown")
            condition_conf = result.get("condition", {}).get("confidence", 0)
            num_materials = len(result.get("materials", []))
            logger.info(f"Vision analysis complete in {total_time:.2f}s: "
                        f"zone={zone} ({zone_conf:.2f}), condition={condition} ({condition_conf:.2f}), "
                        f"materials={num_materials}")
            return result
        except Exception as e:
            logger.error(f"Vision analysis failed: {e}")
            return self._get_fallback_response(str(e))

    def _parse_vision_response(self, response: str) -> dict[str, Any]:
        """Parse JSON response from vision model.

        Tolerates markdown code fences (```json ... ```) that models sometimes
        emit despite the "JSON only" instruction, then extracts the outermost
        {...} span and parses it.
        """
        try:
            # Strip code fences first so the brace search sees only the payload.
            cleaned = re.sub(r"```(?:json)?", "", response)
            json_match = re.search(r'\{[\s\S]*\}', cleaned)
            if json_match:
                return json.loads(json_match.group())
            logger.warning("No JSON found in vision response")
            return self._get_fallback_response("No JSON in response")
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse vision JSON: {e}")
            return self._get_fallback_response(f"JSON parse error: {e}")

    def _get_fallback_response(self, reason: str) -> dict[str, Any]:
        """Return fallback response when analysis fails.

        Conservative defaults (far-field / light, low confidence) plus a
        review flag, so downstream consumers always get a schema-shaped dict.
        """
        return {
            "zone": {
                "classification": "far-field",
                "confidence": 0.3,
                "reasoning": f"Fallback due to: {reason}",
            },
            "condition": {
                "level": "light",
                "confidence": 0.3,
                "reasoning": f"Fallback due to: {reason}",
            },
            "materials": [
                {
                    "type": "general-surface",
                    "category": "semi-porous",
                    "confidence": 0.3,
                    "location_description": "Unable to determine",
                    "bounding_box": {"x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0},
                }
            ],
            "combustion_indicators": {
                "soot_visible": False,
                "soot_pattern": None,
                "char_visible": False,
                "char_description": None,
                "ash_visible": False,
                "ash_description": None,
            },
            "structural_concerns": [],
            "access_issues": [],
            "recommended_sampling_locations": [],
            "flags_for_review": [f"Analysis failed: {reason}"],
            # Marker so callers/telemetry can distinguish real analyses.
            "_fallback_used": True,
        }
class RealEmbeddingModel:
    """Wrapper for real embedding model inference.

    Delegates to the official Qwen3VLEmbedder (QwenLM/Qwen3-VL-Embedding),
    which performs last-token pooling and L2 normalization internally; this
    class only adapts its API and supplies zero-vector fallbacks on failure.
    """

    def __init__(self, model, processor):
        """Initialize with Qwen3VLEmbedder instance.

        Args:
            model: Qwen3VLEmbedder instance (official loader)
            processor: Processor (stored for compatibility, but model has its own)
        """
        self.model = model
        self.processor = processor

    def embed(self, text: str) -> list[float]:
        """Generate embedding for text using official Qwen3VLEmbedder.

        The official model.process() handles tokenization/preprocessing,
        last-token pooling, and L2 normalization.

        Args:
            text: Input text to embed

        Returns:
            List of floats representing the embedding (4096-dim for 8B model),
            or an all-zero vector if inference raises.
        """
        try:
            # process() takes a list of {"text": ...} dicts and returns a
            # (batch, hidden_dim) tensor.
            vectors = self.model.process([{"text": text}], normalize=True)
            return vectors[0].cpu().tolist()
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            # Fall back to a zero vector sized from the model config
            # (4096 per Qwen3-VL-Embedding-8B if the config is absent).
            dim = getattr(self.model.model.config, "hidden_size", 4096)
            return [0.0] * dim

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a batch of texts.

        Uses the embedder's batch processing for efficiency; on failure every
        input gets a zero vector so the output length always matches.
        """
        try:
            payload = [{"text": t} for t in texts]
            vectors = self.model.process(payload, normalize=True)
            return [v.cpu().tolist() for v in vectors]
        except Exception as e:
            logger.error(f"Batch embedding generation failed: {e}")
            dim = getattr(self.model.model.config, "hidden_size", 4096)
            return [[0.0] * dim for _ in texts]
class RealRerankerModel:
    """Wrapper for real reranker model inference.

    Uses the official Qwen3VLReranker from QwenLM/Qwen3-VL-Embedding.
    The model handles yes/no scoring internally via:
    - Extracts "yes" and "no" token weights from the LM head
    - Creates a binary linear layer: weight = yes_weight - no_weight
    - Scores = sigmoid(linear(last_token_hidden_state))
    Reference: https://github.com/QwenLM/Qwen3-VL-Embedding
    """

    def __init__(self, model, processor):
        """Initialize with Qwen3VLReranker instance.

        Args:
            model: Qwen3VLReranker instance (official loader)
            processor: Processor (stored for compatibility, but model has its own)
        """
        self.model = model
        self.processor = processor

    def rerank(self, query: str, documents: list[str]) -> list[float]:
        """Rerank documents by relevance to query using official Qwen3VLReranker.

        The official model.process() handles message formatting, tokenization,
        yes/no scoring with LM head weights, and sigmoid normalization.

        Args:
            query: The search query
            documents: List of documents to rerank

        Returns:
            List of relevance scores (0-1), one per document, in input order.
            Higher scores indicate more relevant documents. All zeros on failure.
        """
        # Fast path: nothing to score (also avoids a pointless model call).
        if not documents:
            return []
        try:
            # Official process() API expects a dict with instruction, query,
            # and documents.
            inputs = {
                "instruction": "Retrieve relevant documents for the query.",
                "query": {"text": query},
                "documents": [{"text": doc} for doc in documents],
            }
            scores = self.model.process(inputs)
            return scores
        except Exception as e:
            logger.error(f"Reranking failed: {e}")
            # Degrade gracefully: keep output aligned with input length.
            return [0.0] * len(documents)

    def rerank_with_indices(
        self, query: str, documents: list[str], top_k: "int | None" = None
    ) -> list[tuple[int, float]]:
        """Rerank and return sorted (index, score) tuples.

        Annotation fix: top_k is Optional (default None = no limit); the
        previous `int = None` annotation was incorrect.

        Args:
            query: The search query
            documents: List of documents to rerank
            top_k: Optional limit on number of results

        Returns:
            List of (original_index, score) tuples, sorted by score descending
        """
        scores = self.rerank(query, documents)
        # Pair each score with its original position, best-scoring first.
        indexed_scores = list(enumerate(scores))
        indexed_scores.sort(key=lambda x: x[1], reverse=True)
        if top_k is not None:
            indexed_scores = indexed_scores[:top_k]
        return indexed_scores