| | """ |
| | ShortSmith v2 - Visual Analyzer Module |
| | |
| | Visual analysis using Qwen2-VL-2B for: |
| | - Scene understanding and description |
| | - Action/event detection |
| | - Emotion recognition |
| | - Visual hype scoring |
| | |
| | Uses quantization (INT4/INT8) for efficient inference on consumer GPUs. |
| | """ |
| |
|
| | from pathlib import Path |
| | from typing import List, Optional, Dict, Any, Union |
| | from dataclasses import dataclass |
| | import numpy as np |
| |
|
| | from utils.logger import get_logger, LogTimer |
| | from utils.helpers import ModelLoadError, InferenceError |
| | from config import get_config, ModelConfig |
| |
|
| | logger = get_logger("models.visual_analyzer") |
| |
|
| |
|
@dataclass
class VisualFeatures:
    """Per-frame (or video-segment) visual analysis results.

    Scalar fields are serializable via ``to_dict()``; the optional
    ``embedding`` array is deliberately excluded from that view.
    """
    timestamp: float
    description: str
    hype_score: float
    action_detected: str
    emotion: str
    scene_type: str
    confidence: float
    embedding: Optional[np.ndarray] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the scalar fields as a plain dict (embedding omitted).

        Note the ``action_detected`` attribute is exported under the
        shorter key ``"action"``.
        """
        key_to_attr = (
            ("timestamp", "timestamp"),
            ("description", "description"),
            ("hype_score", "hype_score"),
            ("action", "action_detected"),
            ("emotion", "emotion"),
            ("scene_type", "scene_type"),
            ("confidence", "confidence"),
        )
        return {key: getattr(self, attr) for key, attr in key_to_attr}
| |
|
| |
|
class VisualAnalyzer:
    """
    Visual analysis using Qwen2-VL-2B model.

    Supports:
    - Single frame analysis
    - Sequential multi-frame processing
    - Video segment understanding
    - Custom prompt-based analysis

    Each frame analysis issues four separate generation calls (hype,
    description, action, emotion); replies are parsed heuristically by
    the corresponding `_get_*` helpers.
    """

    # Prompt templates sent verbatim to the model.
    HYPE_PROMPT = """Analyze this image and rate its excitement/hype level from 0 to 10.
Consider: action intensity, crowd energy, dramatic moments, emotional peaks.
Respond with just a number from 0-10."""

    DESCRIPTION_PROMPT = """Briefly describe what's happening in this image in one sentence.
Focus on the main action, people, and mood."""

    ACTION_PROMPT = """What action or event is happening in this image?
Choose from: celebration, performance, speech, reaction, action, calm, transition, other.
Respond with just the action type."""

    EMOTION_PROMPT = """What is the dominant emotion or mood in this image?
Choose from: excitement, joy, tension, surprise, calm, sadness, anger, neutral.
Respond with just the emotion."""

    def __init__(
        self,
        config: Optional[ModelConfig] = None,
        load_model: bool = True,
    ):
        """
        Initialize visual analyzer.

        Args:
            config: Model configuration (uses default if None)
            load_model: Whether to load model immediately; pass False to
                defer the expensive download/load (e.g. for tests)

        Raises:
            ModelLoadError: If model loading fails
        """
        self.config = config or get_config().model
        self.model = None       # set by _load_model()
        self.processor = None   # set by _load_model()
        self._device = None     # "cuda" or "cpu", decided at load time

        if load_model:
            self._load_model()

        logger.info(f"VisualAnalyzer initialized (model={self.config.visual_model_id})")

    def _load_model(self) -> None:
        """Load the Qwen2-VL model with optional INT4/INT8 quantization.

        Reads the HF_TOKEN environment variable for gated/private models
        and falls back to an unquantized load when bitsandbytes is not
        installed.

        Raises:
            ModelLoadError: If any step of the load fails.
        """
        with LogTimer(logger, "Loading Qwen2-VL model"):
            try:
                import os
                import torch
                from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

                # Optional Hugging Face auth token (None is accepted by
                # from_pretrained).
                hf_token = os.environ.get("HF_TOKEN")

                # Honor the configured device, but fall back to CPU when
                # CUDA is unavailable.
                if self.config.device == "cuda" and torch.cuda.is_available():
                    self._device = "cuda"
                else:
                    self._device = "cpu"

                logger.info(f"Loading model on {self._device}")

                self.processor = AutoProcessor.from_pretrained(
                    self.config.visual_model_id,
                    trust_remote_code=True,
                    token=hf_token,
                )

                # device_map="auto" lets accelerate place/shard the model
                # on GPU(s); on CPU the model is moved explicitly below.
                model_kwargs = {
                    "trust_remote_code": True,
                    "device_map": "auto" if self._device == "cuda" else None,
                }

                # NF4 with double quantization: the usual quality/memory
                # trade-off for consumer GPUs.
                if self.config.visual_model_quantization == "int4":
                    try:
                        from transformers import BitsAndBytesConfig

                        quantization_config = BitsAndBytesConfig(
                            load_in_4bit=True,
                            bnb_4bit_compute_dtype=torch.float16,
                            bnb_4bit_use_double_quant=True,
                            bnb_4bit_quant_type="nf4",
                        )
                        model_kwargs["quantization_config"] = quantization_config
                        logger.info("Using INT4 quantization")
                    except ImportError:
                        logger.warning("bitsandbytes not available, loading without quantization")

                elif self.config.visual_model_quantization == "int8":
                    try:
                        from transformers import BitsAndBytesConfig

                        quantization_config = BitsAndBytesConfig(
                            load_in_8bit=True,
                        )
                        model_kwargs["quantization_config"] = quantization_config
                        logger.info("Using INT8 quantization")
                    except ImportError:
                        logger.warning("bitsandbytes not available, loading without quantization")

                # NOTE(review): bitsandbytes quantization generally
                # requires CUDA; a quantization_config combined with the
                # CPU fallback path may fail here — confirm intended
                # behavior on CPU-only hosts.
                self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                    self.config.visual_model_id,
                    token=hf_token,
                    **model_kwargs,
                )

                if self._device == "cpu":
                    self.model = self.model.to(self._device)

                self.model.eval()  # inference-only; disables dropout etc.
                logger.info("Qwen2-VL model loaded successfully")

            except Exception as e:
                logger.error(f"Failed to load Qwen2-VL model: {e}")
                raise ModelLoadError(f"Could not load visual model: {e}") from e

    def analyze_frame(
        self,
        image: Union[str, Path, np.ndarray, "PIL.Image.Image"],
        timestamp: float = 0.0,
    ) -> VisualFeatures:
        """
        Analyze a single frame/image.

        Issues four model queries (hype, description, action, emotion)
        and combines the parsed answers into one VisualFeatures record.

        Args:
            image: Image path, numpy array, or PIL Image
            timestamp: Timestamp for this frame

        Returns:
            VisualFeatures with analysis results

        Raises:
            ModelLoadError: If the model has not been loaded
            InferenceError: If analysis fails
        """
        if self.model is None:
            raise ModelLoadError("Model not loaded. Call _load_model() first.")

        try:
            from PIL import Image as PILImage

            # Normalize the input to an RGB PIL image.
            if isinstance(image, (str, Path)):
                pil_image = PILImage.open(image).convert("RGB")
            elif isinstance(image, np.ndarray):
                pil_image = PILImage.fromarray(image).convert("RGB")
            else:
                pil_image = image

            hype_score = self._get_hype_score(pil_image)
            description = self._get_description(pil_image)
            action = self._get_action(pil_image)
            emotion = self._get_emotion(pil_image)

            return VisualFeatures(
                timestamp=timestamp,
                description=description,
                hype_score=hype_score,
                action_detected=action,
                emotion=emotion,
                scene_type=self._classify_scene(action, emotion),
                # Fixed heuristic value — the model provides no
                # confidence signal of its own.
                confidence=0.8,
            )

        except Exception as e:
            logger.error(f"Frame analysis failed: {e}")
            raise InferenceError(f"Visual analysis failed: {e}") from e

    def _query_model(
        self,
        image: "PIL.Image.Image",
        prompt: str,
        max_tokens: int = 50,
    ) -> str:
        """Send a single image+text query to the model.

        Args:
            image: RGB PIL image to analyze
            prompt: Instruction text for the model
            max_tokens: Generation budget for the reply

        Returns:
            The model's reply stripped of whitespace, or "" on any
            failure (callers treat "" as "no answer" and use defaults).
        """
        import torch

        try:
            # Qwen2-VL chat format: one user turn containing an image
            # part and a text part.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            # Render the chat template to a prompt string (not tokenized
            # yet); the processor tokenizes text and image together below.
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            inputs = self.processor(
                text=[text],
                images=[image],
                padding=True,
                return_tensors="pt",
            )

            # Move tensors to the GPU; non-tensor entries pass through.
            if self._device == "cuda":
                inputs = {k: v.cuda() if hasattr(v, 'cuda') else v for k, v in inputs.items()}

            # Greedy decoding (do_sample=False) for deterministic,
            # easily parseable answers.
            with torch.no_grad():
                output_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    do_sample=False,
                )

            # Slice off the prompt tokens; decode only the generated tail.
            response = self.processor.batch_decode(
                output_ids[:, inputs['input_ids'].shape[1]:],
                skip_special_tokens=True,
            )[0]

            return response.strip()

        except Exception as e:
            # Best-effort: log and return "" so one bad query does not
            # abort a whole frame analysis.
            logger.warning(f"Model query failed: {e}")
            return ""

    def _get_hype_score(self, image: "PIL.Image.Image") -> float:
        """Query the hype prompt and parse the reply into a 0-1 score.

        Returns 0.5 (neutral) when no number can be extracted.
        """
        response = self._query_model(image, self.HYPE_PROMPT, max_tokens=10)

        try:
            # Take the first number in the reply; the prompt asks for a
            # 0-10 rating, so divide by 10 and clamp the top at 1.0.
            import re
            numbers = re.findall(r'\d+(?:\.\d+)?', response)
            if numbers:
                score = float(numbers[0])
                return min(1.0, score / 10.0)
        except (ValueError, IndexError):
            pass

        return 0.5

    def _get_description(self, image: "PIL.Image.Image") -> str:
        """Query for a one-sentence description (fallback text on failure)."""
        response = self._query_model(image, self.DESCRIPTION_PROMPT, max_tokens=100)
        return response if response else "Unable to describe"

    def _get_action(self, image: "PIL.Image.Image") -> str:
        """Query for the action type and map the free-text reply onto the
        closed action vocabulary, defaulting to "other"."""
        response = self._query_model(image, self.ACTION_PROMPT, max_tokens=20)
        actions = ["celebration", "performance", "speech", "reaction", "action", "calm", "transition", "other"]

        # Substring match: the first listed action found in the reply wins.
        response_lower = response.lower()
        for action in actions:
            if action in response_lower:
                return action

        return "other"

    def _get_emotion(self, image: "PIL.Image.Image") -> str:
        """Query for the dominant emotion and map the reply onto the
        closed emotion vocabulary, defaulting to "neutral"."""
        response = self._query_model(image, self.EMOTION_PROMPT, max_tokens=20)
        emotions = ["excitement", "joy", "tension", "surprise", "calm", "sadness", "anger", "neutral"]

        # Substring match against the lowercase reply.
        response_lower = response.lower()
        for emotion in emotions:
            if emotion in response_lower:
                return emotion

        return "neutral"

    def _classify_scene(self, action: str, emotion: str) -> str:
        """Classify scene type from the (action, emotion) pair.

        "highlight" requires both high energy and high emotion;
        "active"/"emotional" when only one applies; otherwise "neutral".
        """
        high_energy = {"celebration", "performance", "action"}
        high_emotion = {"excitement", "joy", "surprise", "tension"}

        if action in high_energy and emotion in high_emotion:
            return "highlight"
        elif action in high_energy:
            return "active"
        elif emotion in high_emotion:
            return "emotional"
        else:
            return "neutral"

    def analyze_frames_batch(
        self,
        images: List[Union[str, Path, np.ndarray]],
        timestamps: Optional[List[float]] = None,
    ) -> List[VisualFeatures]:
        """
        Analyze multiple frames, one frame at a time.

        Note: despite the name, frames are processed sequentially, not
        in model-level batches. A frame that fails yields a
        zero-confidence placeholder instead of aborting the run, so the
        output stays aligned with the input.

        Args:
            images: List of images (paths or arrays)
            timestamps: Timestamps for each image; defaults to 1-second
                spacing starting at 0.0

        Returns:
            List of VisualFeatures, one per input image, in order
        """
        if timestamps is None:
            timestamps = [i * 1.0 for i in range(len(images))]

        results = []

        with LogTimer(logger, f"Analyzing {len(images)} frames"):
            for i, (image, ts) in enumerate(zip(images, timestamps)):
                try:
                    features = self.analyze_frame(image, timestamp=ts)
                    results.append(features)

                    # Progress log every 10 frames.
                    if (i + 1) % 10 == 0:
                        logger.debug(f"Processed {i + 1}/{len(images)} frames")

                except Exception as e:
                    logger.warning(f"Failed to analyze frame {i}: {e}")
                    # Placeholder keeps results index-aligned with images.
                    results.append(VisualFeatures(
                        timestamp=ts,
                        description="Analysis failed",
                        hype_score=0.5,
                        action_detected="unknown",
                        emotion="neutral",
                        scene_type="neutral",
                        confidence=0.0,
                    ))

        return results

    def analyze_with_custom_prompt(
        self,
        image: Union[str, Path, np.ndarray, "PIL.Image.Image"],
        prompt: str,
        timestamp: float = 0.0,
    ) -> Dict[str, Any]:
        """
        Analyze image with a custom user-supplied criteria prompt.

        The user prompt is wrapped in a Yes/No/Partially question before
        being sent to the model.

        Args:
            image: Image to analyze
            prompt: Custom analysis prompt from user
            timestamp: Timestamp for this frame

        Returns:
            Dictionary with prompt, response, and timestamp (response is
            "" if the model query failed)
        """
        from PIL import Image as PILImage

        # Normalize the input to an RGB PIL image.
        if isinstance(image, (str, Path)):
            pil_image = PILImage.open(image).convert("RGB")
        elif isinstance(image, np.ndarray):
            pil_image = PILImage.fromarray(image).convert("RGB")
        else:
            pil_image = image

        # Wrap the user's criteria in a fixed question format.
        formatted_prompt = (
            f"Analyze this frame for the following criteria: {prompt}\n\n"
            f"Does this frame match these criteria? "
            f"Respond with 'Yes' or 'No' followed by a brief explanation. "
            f"If it partially matches, say 'Partially' and explain what matches."
        )

        response = self._query_model(pil_image, formatted_prompt, max_tokens=200)

        return {
            "timestamp": timestamp,
            "prompt": prompt,
            "response": response,
        }

    def get_frame_embedding(
        self,
        image: Union[str, Path, np.ndarray, "PIL.Image.Image"],
    ) -> Optional[np.ndarray]:
        """
        Get visual embedding for a frame.

        Not implemented: this interface does not expose a dedicated
        image embedding from Qwen2-VL, so it always logs a warning and
        returns None.

        Args:
            image: Image to embed

        Returns:
            None (embedding extraction unsupported)
        """
        logger.warning("Frame embedding not directly supported by Qwen2-VL")
        return None

    def unload_model(self) -> None:
        """Unload model and processor to free GPU/CPU memory."""
        if self.model is not None:
            del self.model
            self.model = None

        if self.processor is not None:
            del self.processor
            self.processor = None

        # Release cached CUDA allocations if torch is present.
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass

        logger.info("Visual model unloaded")
| |
|
| |
|
| | |
# Explicit public API for `from ... import *` consumers.
__all__ = ["VisualAnalyzer", "VisualFeatures"]
| |
|