| """ | |
| SmolVLM2 Model Handler | |
| Handles loading and inference with SmolVLM2-256M-Instruct model (smallest model for HuggingFace Spaces) | |
| """ | |
import os

# Set cache directories to writable locations for HuggingFace Spaces
if 'HF_HOME' not in os.environ:
    # Use /tmp, which is guaranteed to be writable in containers
    CACHE_DIR = os.path.join("/tmp", ".cache", "huggingface")
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(os.path.join("/tmp", ".cache", "torch"), exist_ok=True)
    os.environ['HF_HOME'] = CACHE_DIR
    os.environ['HF_DATASETS_CACHE'] = CACHE_DIR
    os.environ['TORCH_HOME'] = os.path.join("/tmp", ".cache", "torch")
    os.environ['XDG_CACHE_HOME'] = os.path.join("/tmp", ".cache")
    os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR

os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import requests
from typing import List, Union
import logging
import warnings

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress some warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)


class SmolVLM2Handler:
    """Handler for SmolVLM2 model operations"""

    def __init__(self, model_name: str = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct", device: str = "auto"):
        """
        Initialize the SmolVLM2 model (256M Video-Instruct, the smallest variant)

        Args:
            model_name: HuggingFace model identifier
            device: Device to use ('auto', 'cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.device = self._get_device(device)
        self.model = None
        self.processor = None
        logger.info(f"Initializing SmolVLM2 on device: {self.device}")
        self._load_model()

    def _get_device(self, device: str) -> str:
        """Determine the best device to use with graceful fallback."""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        requested = device.lower()
        if requested == "cuda" and not torch.cuda.is_available():
            logger.warning("CUDA requested but not available. Falling back to CPU.")
            return "cpu"
        if requested == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
            logger.warning("MPS requested but not available. Falling back to CPU.")
            return "cpu"
        return requested

    def _get_torch_dtype(self) -> torch.dtype:
        """Pick dtype based on the selected device."""
        if self.device == "cuda":
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        if self.device == "mps":
            return torch.float16
        return torch.float32

    def _load_model(self):
        """Load the model and processor"""
        try:
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            logger.info("Loading model...")
            dtype = self._get_torch_dtype()
            logger.info(f"Using torch dtype: {dtype}")
            try:
                self.model = AutoModelForImageTextToText.from_pretrained(
                    self.model_name,
                    dtype=dtype,
                    trust_remote_code=True
                )
            except TypeError:
                # Backward compatibility for older Transformers versions.
                self.model = AutoModelForImageTextToText.from_pretrained(
                    self.model_name,
                    torch_dtype=dtype,
                    trust_remote_code=True
                )
            self.model = self.model.to(self.device)
            logger.info("✅ Model loaded successfully!")
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            raise

    def process_image(self, image_input: Union[str, Image.Image]) -> Image.Image:
        """
        Process image input into a PIL Image

        Args:
            image_input: File path, URL, or PIL Image

        Returns:
            PIL Image object
        """
        if isinstance(image_input, str):
            if image_input.startswith(('http://', 'https://')):
                # Download from URL with a single streamed request; fail fast on HTTP errors
                response = requests.get(image_input, stream=True, timeout=30)
                response.raise_for_status()
                image = Image.open(response.raw)
            else:
                # Load from file path
                image = Image.open(image_input)
        elif isinstance(image_input, Image.Image):
            image = image_input
        else:
            raise ValueError("Image input must be a file path, URL, or PIL Image")

        # Convert to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
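
    # Usage sketch: process_image accepts all three input forms interchangeably.
    # The path and URL below are hypothetical placeholders, e.g.
    #   handler.process_image("photo.jpg")                    # local path (hypothetical)
    #   handler.process_image("https://example.com/cat.png")  # URL (hypothetical)
    #   handler.process_image(Image.new("RGB", (64, 64)))     # in-memory PIL image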

    def generate_response(
        self,
        image_input: Union[str, Image.Image, List[Image.Image]],
        text_prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        do_sample: bool = True
    ) -> str:
        """
        Generate a response from image(s) and a text prompt

        Args:
            image_input: Single image or list of images
            text_prompt: Text prompt/question
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            do_sample: Whether to use sampling

        Returns:
            Generated text response
        """
        try:
            # Process images
            if isinstance(image_input, list):
                images = [self.process_image(img) for img in image_input]
            else:
                images = [self.process_image(image_input)]

            # Create the conversation format SmolVLM2 expects
            messages = [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": text_prompt}]
                }
            ]
            # Prepend one image entry per image to the message content
            for img in images:
                messages[0]["content"].insert(0, {"type": "image", "image": img})

            # Apply chat template
            try:
                prompt = self.processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True
                )
            except Exception:
                # Fall back to a simple format if the chat template fails
                image_tokens = "<image>" * len(images)
                prompt = f"{image_tokens}{text_prompt}"
            # Prepare inputs
            inputs = self.processor(
                images=images,
                text=prompt,
                return_tensors="pt"
            ).to(self.device)

            # Generate with robust decoding parameters
            with torch.no_grad():
                try:
                    generated_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        do_sample=do_sample,
                        top_p=0.85,  # Slightly lower top_p for more focused responses
                        top_k=40,  # top_k for tighter control
                        repetition_penalty=1.2,  # Higher repetition penalty
                        pad_token_id=self.processor.tokenizer.eos_token_id,
                        eos_token_id=self.processor.tokenizer.eos_token_id,
                        use_cache=True
                    )
                except RuntimeError as e:
                    if "probability tensor" in str(e) or "nan" in str(e) or "inf" in str(e):
                        # Retry with more conservative parameters
                        logger.warning("Retrying with conservative parameters due to probability tensor error")
                        generated_ids = self.model.generate(
                            **inputs,
                            max_new_tokens=min(max_new_tokens, 256),
                            temperature=min(temperature, 0.5),
                            do_sample=do_sample,
                            top_p=0.9,
                            pad_token_id=self.processor.tokenizer.eos_token_id,
                            eos_token_id=self.processor.tokenizer.eos_token_id,
                            use_cache=True
                        )
                    else:
                        raise

            # Decode only the new tokens (skip the input prompt)
            input_length = inputs['input_ids'].shape[1]
            new_tokens = generated_ids[0][input_length:]
            generated_text = self.processor.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True
            ).strip()

            # Return a meaningful response even if generation came back empty
            if not generated_text:
                return "I can see the image but cannot generate a specific description."
            return generated_text

        except Exception as e:
            logger.error(f"❌ Error during generation: {e}")
            raise
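
    # Example calls (a sketch; the prompts are illustrative). With do_sample=False
    # the model decodes greedily and temperature/top_p/top_k are ignored, which is
    # useful when reproducible answers matter more than variety:
    #   handler.generate_response(img, "Describe this image.")                   # sampled
    #   handler.generate_response(img, "Describe this image.", do_sample=False)  # deterministic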

    def analyze_video_frames(
        self,
        frames: List[Image.Image],
        question: str,
        max_frames: int = 8
    ) -> str:
        """
        Analyze video frames and answer questions; see the frame-extraction
        sketch below for one way to build the `frames` list

        Args:
            frames: List of PIL Image frames
            question: Question about the video
            max_frames: Maximum number of frames to process

        Returns:
            Analysis result
        """
        # Subsample evenly if there are too many frames
        if len(frames) > max_frames:
            step = len(frames) // max_frames
            sampled_frames = frames[::step][:max_frames]
        else:
            sampled_frames = frames
        logger.info(f"Analyzing {len(sampled_frames)} frames")

        # Use a simple prompt for video analysis (image tokens are added by the template)
        video_prompt = f"These are frames from a video. {question}"
        return self.generate_response(sampled_frames, video_prompt)
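
    # A helper sketch (an assumed addition, not part of the original API) showing
    # where the `frames` argument above can come from, using OpenCV:
    @staticmethod
    def extract_frames_from_video(video_path: str, num_frames: int = 8) -> List[Image.Image]:
        """
        Minimal sketch of frame extraction for analyze_video_frames(), assuming
        opencv-python is installed; the import is kept local so the rest of the
        module still works without it.
        """
        import cv2  # assumed optional dependency (opencv-python)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        try:
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Pick evenly spaced frame indices across the whole video
            indices = [int(i * total / num_frames) for i in range(num_frames)] if total > 0 else []
            frames = []
            for idx in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ok, frame = cap.read()
                if not ok:
                    continue
                # OpenCV decodes to BGR; convert to the RGB order PIL expects
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            return frames
        finally:
            cap.release()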

    def get_model_info(self) -> dict:
        """Get information about the loaded model"""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "model_type": type(self.model).__name__,
            "processor_type": type(self.processor).__name__,
            "loaded": self.model is not None and self.processor is not None
        }


def test_model():
    """Test the model with a simple example"""
    try:
        # Initialize model
        vlm = SmolVLM2Handler()
        print("Model Info:")
        info = vlm.get_model_info()
        for key, value in info.items():
            print(f"  {key}: {value}")

        # Test with a simple, synthetic image
        test_image = Image.new('RGB', (224, 224), color='blue')
        test_prompt = "What color is this image?"
        print(f"\nTesting with prompt: '{test_prompt}'")
        response = vlm.generate_response(test_image, test_prompt)
        print(f"Response: {response}")
        print("\n✅ Model test completed successfully!")
    except Exception as e:
        print(f"❌ Model test failed: {e}")
        raise
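

# Usage sketch (an illustrative addition): loading the handler once and reusing
# it across prompts amortizes the model-load cost, which dominates on CPU-only
# Spaces. The prompts below are hypothetical.
def example_multi_prompt():
    vlm = SmolVLM2Handler()
    image = Image.new('RGB', (224, 224), color='red')
    for prompt in ["What color is this image?", "Describe the image in one sentence."]:
        print(f"Q: {prompt}")
        print(f"A: {vlm.generate_response(image, prompt, max_new_tokens=64)}")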


if __name__ == "__main__":
    test_model()