Spaces:
Paused
Paused
| """ | |
| ShortSmith v2 - Visual Analyzer Module | |
| Visual analysis using Qwen2-VL-2B for: | |
| - Scene understanding and description | |
| - Action/event detection | |
| - Emotion recognition | |
| - Visual hype scoring | |
| Uses quantization (INT4/INT8) for efficient inference on consumer GPUs. | |
| """ | |
| from pathlib import Path | |
| from typing import List, Optional, Dict, Any, Union | |
| from dataclasses import dataclass | |
| import numpy as np | |
| from utils.logger import get_logger, LogTimer | |
| from utils.helpers import ModelLoadError, InferenceError, batch_list | |
| from config import get_config, ModelConfig | |
# Module-level logger shared by everything in this file; the name
# "models.visual_analyzer" is handed to the project's get_logger helper
# (imported from utils.logger).
logger = get_logger("models.visual_analyzer")
@dataclass
class VisualFeatures:
    """Visual features extracted from a frame or video segment.

    Instances are built with keyword arguments (e.g. by
    ``VisualAnalyzer.analyze_frame``), which requires the ``@dataclass``
    decorator to generate ``__init__``.
    """

    timestamp: float  # Timestamp in seconds
    description: str  # Natural language description
    hype_score: float  # Visual excitement score (0-1)
    action_detected: str  # Detected action/event
    emotion: str  # Detected emotion/mood
    scene_type: str  # Scene classification
    confidence: float  # Model confidence (0-1)
    # Raw embedding if available
    embedding: Optional[np.ndarray] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dictionary.

        Note that ``action_detected`` is exported under the key ``"action"``
        and that ``embedding`` is not included in the output.

        Returns:
            Dict with timestamp, description, hype_score, action, emotion,
            scene_type and confidence.
        """
        return {
            "timestamp": self.timestamp,
            "description": self.description,
            "hype_score": self.hype_score,
            "action": self.action_detected,
            "emotion": self.emotion,
            "scene_type": self.scene_type,
            "confidence": self.confidence,
        }
class VisualAnalyzer:
    """
    Visual analysis using a Qwen2-VL model (default: Qwen2-VL-2B).

    Supports:
    - Single frame analysis (hype score, description, action, emotion)
    - Multi-frame analysis (frames are analyzed one at a time)
    - Custom prompt-based analysis
    """

    # Prompts for different analysis tasks
    HYPE_PROMPT = """Analyze this image and rate its excitement/hype level from 0 to 10.
Consider: action intensity, crowd energy, dramatic moments, emotional peaks.
Respond with just a number from 0-10."""

    DESCRIPTION_PROMPT = """Briefly describe what's happening in this image in one sentence.
Focus on the main action, people, and mood."""

    ACTION_PROMPT = """What action or event is happening in this image?
Choose from: celebration, performance, speech, reaction, action, calm, transition, other.
Respond with just the action type."""

    EMOTION_PROMPT = """What is the dominant emotion or mood in this image?
Choose from: excitement, joy, tension, surprise, calm, sadness, anger, neutral.
Respond with just the emotion."""

    # Label vocabularies matching the prompts above. Matching is by substring
    # in tuple order, so "reaction" must come before "action" (which is a
    # substring of it).
    ACTION_LABELS = (
        "celebration", "performance", "speech", "reaction",
        "action", "calm", "transition", "other",
    )
    EMOTION_LABELS = (
        "excitement", "joy", "tension", "surprise",
        "calm", "sadness", "anger", "neutral",
    )

    def __init__(
        self,
        config: Optional[ModelConfig] = None,
        load_model: bool = True,
    ):
        """
        Initialize visual analyzer.

        Args:
            config: Model configuration (uses get_config().model if None)
            load_model: Whether to load model immediately

        Raises:
            ModelLoadError: If model loading fails
        """
        self.config = config or get_config().model
        self.model = None
        self.processor = None
        self._device = None

        if load_model:
            self._load_model()

        logger.info(f"VisualAnalyzer initialized (model={self.config.visual_model_id})")

    def _load_model(self) -> None:
        """Load the Qwen2-VL processor and model, optionally quantized.

        Raises:
            ModelLoadError: If the processor or model cannot be loaded.
        """
        with LogTimer(logger, "Loading Qwen2-VL model"):
            try:
                import os
                import torch
                from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

                # HuggingFace token from environment (optional - model is open access)
                hf_token = os.environ.get("HF_TOKEN")

                # Use CUDA only when requested AND available; otherwise CPU.
                if self.config.device == "cuda" and torch.cuda.is_available():
                    self._device = "cuda"
                else:
                    self._device = "cpu"
                logger.info(f"Loading model on {self._device}")

                self.processor = AutoProcessor.from_pretrained(
                    self.config.visual_model_id,
                    trust_remote_code=True,
                    token=hf_token,
                )

                model_kwargs = {
                    "trust_remote_code": True,
                    "device_map": "auto" if self._device == "cuda" else None,
                }
                quantization_config = self._build_quantization_config()
                if quantization_config is not None:
                    model_kwargs["quantization_config"] = quantization_config

                self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                    self.config.visual_model_id,
                    token=hf_token,
                    **model_kwargs,
                )

                # device_map is None on CPU, so place the model explicitly.
                # Quantization is never applied on CPU (see
                # _build_quantization_config), so .to() is safe here.
                if self._device == "cpu":
                    self.model = self.model.to(self._device)

                self.model.eval()
                logger.info("Qwen2-VL model loaded successfully")

            except Exception as e:
                logger.error(f"Failed to load Qwen2-VL model: {e}")
                raise ModelLoadError(f"Could not load visual model: {e}") from e

    def _build_quantization_config(self):
        """Return a BitsAndBytesConfig for the configured mode, or None.

        Quantization is skipped (with a warning) when the mode is not
        "int4"/"int8", when the model is being loaded on CPU, or when the
        bitsandbytes package is not installed. `BitsAndBytesConfig` itself is
        importable without bitsandbytes, so the package is probed explicitly
        rather than relying on an ImportError from the transformers import.
        """
        import importlib.util

        mode = self.config.visual_model_quantization
        if mode not in ("int4", "int8"):
            return None

        if self._device != "cuda":
            logger.warning(
                f"{mode} quantization requires CUDA; loading without quantization"
            )
            return None

        if importlib.util.find_spec("bitsandbytes") is None:
            logger.warning("bitsandbytes not available, loading without quantization")
            return None

        try:
            import torch
            from transformers import BitsAndBytesConfig
        except ImportError:
            logger.warning("bitsandbytes not available, loading without quantization")
            return None

        if mode == "int4":
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            logger.info("Using INT4 quantization")
        else:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            logger.info("Using INT8 quantization")
        return quantization_config

    def _require_model(self) -> None:
        """Raise ModelLoadError if the model has not been loaded."""
        if self.model is None:
            raise ModelLoadError("Model not loaded. Call _load_model() first.")

    @staticmethod
    def _to_pil(image: Union[str, Path, np.ndarray, "PIL.Image.Image"]) -> "PIL.Image.Image":
        """Convert a path, numpy array, or PIL image to an RGB PIL image.

        Files are opened in a ``with`` block so the handle is closed once the
        pixels have been copied by ``convert``.
        """
        from PIL import Image as PILImage

        if isinstance(image, (str, Path)):
            with PILImage.open(image) as opened:
                return opened.convert("RGB")
        if isinstance(image, np.ndarray):
            return PILImage.fromarray(image).convert("RGB")
        # Assumed to already be a PIL image; normalise mode like the other branches.
        if getattr(image, "mode", "RGB") != "RGB":
            return image.convert("RGB")
        return image

    def analyze_frame(
        self,
        image: Union[str, Path, np.ndarray, "PIL.Image.Image"],
        prompt: Optional[str] = None,
        timestamp: float = 0.0,
    ) -> VisualFeatures:
        """
        Analyze a single frame/image.

        Args:
            image: Image path, numpy array, or PIL Image
            prompt: Custom prompt for the description query
                (uses DESCRIPTION_PROMPT if None)
            timestamp: Timestamp for this frame

        Returns:
            VisualFeatures with analysis results

        Raises:
            ModelLoadError: If the model is not loaded
            InferenceError: If analysis fails
        """
        self._require_model()

        try:
            pil_image = self._to_pil(image)

            hype_score = self._get_hype_score(pil_image)
            description = self._get_description(pil_image, prompt)
            action = self._get_action(pil_image)
            emotion = self._get_emotion(pil_image)

            return VisualFeatures(
                timestamp=timestamp,
                description=description,
                hype_score=hype_score,
                action_detected=action,
                emotion=emotion,
                scene_type=self._classify_scene(action, emotion),
                # Fixed value: no per-query confidence is computed here.
                confidence=0.8,
            )

        except Exception as e:
            logger.error(f"Frame analysis failed: {e}")
            raise InferenceError(f"Visual analysis failed: {e}") from e

    def _query_model(
        self,
        image: "PIL.Image.Image",
        prompt: str,
        max_tokens: int = 50,
    ) -> str:
        """Send one image+text query to the model and return the decoded reply.

        Best-effort: any failure is logged as a warning and "" is returned,
        so callers fall back to their default values.
        """
        import torch

        try:
            # Qwen2-VL chat format: one user turn with an image and a text part.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            inputs = self.processor(
                text=[text],
                images=[image],
                padding=True,
                return_tensors="pt",
            )

            if self._device == "cuda":
                inputs = {k: v.cuda() if hasattr(v, "cuda") else v for k, v in inputs.items()}

            # Greedy decoding for deterministic short answers.
            with torch.no_grad():
                output_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    do_sample=False,
                )

            # Strip the prompt tokens so only the generated reply is decoded.
            response = self.processor.batch_decode(
                output_ids[:, inputs["input_ids"].shape[1]:],
                skip_special_tokens=True,
            )[0]

            return response.strip()

        except Exception as e:
            logger.warning(f"Model query failed: {e}")
            return ""

    def _get_hype_score(self, image: "PIL.Image.Image") -> float:
        """Ask for a 0-10 hype rating and normalise the first number to 0-1.

        Returns 0.5 when the reply contains no number; values above 10 are
        clamped to 1.0.
        """
        import re

        response = self._query_model(image, self.HYPE_PROMPT, max_tokens=10)
        match = re.search(r"\d+(?:\.\d+)?", response)
        if match is None:
            return 0.5  # Default middle score
        return min(1.0, float(match.group()) / 10.0)

    def _get_description(
        self,
        image: "PIL.Image.Image",
        prompt: Optional[str] = None,
    ) -> str:
        """Get a one-sentence description (or a reply to a custom prompt)."""
        response = self._query_model(
            image, prompt or self.DESCRIPTION_PROMPT, max_tokens=100
        )
        return response if response else "Unable to describe"

    def _get_action(self, image: "PIL.Image.Image") -> str:
        """Get the action type: the first ACTION_LABELS entry found in the reply."""
        response = self._query_model(image, self.ACTION_PROMPT, max_tokens=20).lower()
        return next((label for label in self.ACTION_LABELS if label in response), "other")

    def _get_emotion(self, image: "PIL.Image.Image") -> str:
        """Get the emotion: the first EMOTION_LABELS entry found in the reply."""
        response = self._query_model(image, self.EMOTION_PROMPT, max_tokens=20).lower()
        return next((label for label in self.EMOTION_LABELS if label in response), "neutral")

    def _classify_scene(self, action: str, emotion: str) -> str:
        """Classify the scene as highlight / active / emotional / neutral."""
        high_energy = {"celebration", "performance", "action"}
        high_emotion = {"excitement", "joy", "surprise", "tension"}

        if action in high_energy and emotion in high_emotion:
            return "highlight"
        elif action in high_energy:
            return "active"
        elif emotion in high_emotion:
            return "emotional"
        else:
            return "neutral"

    def analyze_frames_batch(
        self,
        images: List[Union[str, Path, np.ndarray]],
        timestamps: Optional[List[float]] = None,
        batch_size: int = 4,
    ) -> List[VisualFeatures]:
        """
        Analyze multiple frames.

        Frames are analyzed one at a time; a frame whose analysis fails gets a
        placeholder result with confidence 0.0.

        Args:
            images: List of images (paths or arrays)
            timestamps: Timestamps for each image (defaults to 0, 1, 2, ...)
            batch_size: Currently unused; accepted for API compatibility

        Returns:
            List of VisualFeatures, one per analyzed image

        Raises:
            ModelLoadError: If the model is not loaded (checked once up front
                instead of producing a placeholder for every frame)
        """
        self._require_model()

        if timestamps is None:
            timestamps = [float(i) for i in range(len(images))]
        elif len(timestamps) != len(images):
            logger.warning(
                f"Got {len(images)} images but {len(timestamps)} timestamps; "
                f"only the first {min(len(images), len(timestamps))} will be analyzed"
            )

        results = []

        with LogTimer(logger, f"Analyzing {len(images)} frames"):
            for i, (image, ts) in enumerate(zip(images, timestamps)):
                try:
                    results.append(self.analyze_frame(image, timestamp=ts))

                    if (i + 1) % 10 == 0:
                        logger.debug(f"Processed {i + 1}/{len(images)} frames")

                except Exception as e:
                    logger.warning(f"Failed to analyze frame {i}: {e}")
                    # Add placeholder
                    results.append(VisualFeatures(
                        timestamp=ts,
                        description="Analysis failed",
                        hype_score=0.5,
                        action_detected="unknown",
                        emotion="neutral",
                        scene_type="neutral",
                        confidence=0.0,
                    ))

        return results

    def analyze_with_custom_prompt(
        self,
        image: Union[str, Path, np.ndarray, "PIL.Image.Image"],
        prompt: str,
        timestamp: float = 0.0,
    ) -> Dict[str, Any]:
        """
        Analyze image with a custom prompt.

        Args:
            image: Image to analyze
            prompt: Custom analysis prompt
            timestamp: Timestamp for this frame

        Returns:
            Dictionary with prompt, response, and timestamp. The response is
            "" if the model query failed.

        Raises:
            ModelLoadError: If the model is not loaded
        """
        self._require_model()
        pil_image = self._to_pil(image)

        response = self._query_model(pil_image, prompt, max_tokens=200)

        return {
            "timestamp": timestamp,
            "prompt": prompt,
            "response": response,
        }

    def get_frame_embedding(
        self,
        image: Union[str, Path, np.ndarray, "PIL.Image.Image"],
    ) -> Optional[np.ndarray]:
        """
        Get visual embedding for a frame.

        Not implemented for Qwen2-VL: always logs a warning and returns None.

        Args:
            image: Image to embed

        Returns:
            None
        """
        # Note: Qwen2-VL doesn't directly expose embeddings
        # This would need a different approach or model
        logger.warning("Frame embedding not directly supported by Qwen2-VL")
        return None

    def unload_model(self) -> None:
        """Drop the model and processor and release cached GPU memory."""
        import gc

        self.model = None
        self.processor = None

        # Collect the now-unreferenced model first so empty_cache can
        # actually return its memory to the driver.
        gc.collect()

        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass

        logger.info("Visual model unloaded")
# Public names re-exported by a star import of this module.
__all__ = [
    "VisualAnalyzer",
    "VisualFeatures",
]