Spaces:
Runtime error
Runtime error
| """ | |
| Model management for Frame 0 Laboratory for MIA | |
| BAGEL 7B integration via API calls | |
| """ | |
| import spaces | |
| import logging | |
| import tempfile | |
| import os | |
| from typing import Optional, Dict, Any, Tuple | |
| from PIL import Image | |
| from gradio_client import Client, handle_file | |
| from config import get_device_config | |
| from utils import clean_memory, safe_execute | |
# Module-scoped logger; records are tagged with this module's import name.
logger = logging.getLogger(name=__name__)
class BaseImageAnalyzer:
    """Common interface shared by every image analyzer in this module.

    Subclasses are expected to override ``initialize`` and ``analyze_image``;
    ``cleanup`` has a default implementation that frees memory through
    ``clean_memory``.
    """

    def __init__(self):
        # Device settings come from the project's config module.
        self.device_config = get_device_config()
        # Subclasses flip this to True once initialize() succeeds.
        self.is_initialized = False

    def initialize(self) -> bool:
        """Prepare the analyzer for use and report whether it is ready.

        Raises:
            NotImplementedError: always; subclasses must provide this.
        """
        raise NotImplementedError

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Describe *image*, returning ``(description, metadata)``.

        Raises:
            NotImplementedError: always; subclasses must provide this.
        """
        raise NotImplementedError

    def cleanup(self) -> None:
        """Release resources held by the analyzer (default: ``clean_memory()``)."""
        clean_memory()
class BagelAPIAnalyzer(BaseImageAnalyzer):
    """BAGEL 7B image understanding through a hosted Gradio Space.

    The model is not run locally. Each analysis writes the image to a
    temporary PNG, uploads it with ``gradio_client`` and calls the Space's
    ``/image_understanding`` endpoint.
    """

    # Prompt used by analyze_image() when the caller passes none.
    DEFAULT_ANALYSIS_PROMPT = (
        "You are analyzing a photograph for FLUX image generation. "
        "Provide a detailed analysis in two sections:\n"
        "1. DESCRIPTION: Start directly with the subject "
        '(e.g., "A color photograph showing..." or "A black and white photograph depicting..."). '
        "First, determine if this is a photograph, illustration, or artwork. "
        "Then describe the visual elements, composition, lighting, colors "
        "(be specific about the color palette - warm tones, cool tones, monochrome, etc.), "
        "artistic style, mood, and atmosphere. "
        "Also mention the image format/aspect ratio (square, portrait, landscape, widescreen, etc.) "
        "and how the composition uses this format. "
        "Write as a flowing paragraph without numbered lists.\n"
        "2. CAMERA_SETUP: Based on the photographic characteristics, scene type, and aspect ratio "
        "you observe, recommend the specific camera system and lens that would realistically "
        "capture this type of scene:\n"
        "- For street/documentary photography: suggest cameras like Canon EOS R6, Sony A7 IV, "
        "Leica Q2 with 35mm or 24-70mm lenses\n"
        "- For portraits: suggest cameras like Canon EOS R5, Sony A7R V with 85mm or 135mm lenses\n"
        "- For landscapes/widescreen: suggest cameras like Phase One XT, Fujifilm GFX with "
        "wide-angle lenses (16-35mm, 24-70mm)\n"
        "- For sports/action: suggest cameras like Canon EOS-1D X, Sony A9 III with telephoto lenses\n"
        "- For macro: suggest specialized macro lenses\n"
        "- For cinematic/widescreen formats: suggest cinema cameras or full-frame with "
        "appropriate aspect ratios\n"
        "Be specific about focal length, aperture, and shooting style based on what you actually "
        "see in the image dimensions and content.\n"
        "Analyze carefully and be accurate about colors, image type, and proportions."
    )

    # Prompt used by analyze_for_flux_prompt().
    FLUX_ANALYSIS_PROMPT = (
        "You are analyzing a photograph for professional FLUX generation. Provide two sections:\n"
        "1. DESCRIPTION: Determine first if this is a real photograph, digital artwork, or illustration. "
        "Then create a detailed, flowing description starting directly with the subject. "
        "Be precise about:\n"
        "- Image type (photograph, illustration, artwork)\n"
        "- Color palette (specify if color or black/white, warm/cool tones, specific colors)\n"
        "- Photographic style (street, portrait, landscape, documentary, artistic, etc.)\n"
        "- Composition, lighting, mood, and atmosphere\n"
        "Write as a single coherent paragraph.\n"
        "2. CAMERA_SETUP: Recommend specific professional equipment that would realistically "
        "capture this exact scene:\n"
        "- Street/urban scenes: Canon EOS R6, Sony A7 IV, Leica Q2 with 24-70mm f/2.8 or 35mm f/1.4\n"
        "- Portraits: Canon EOS R5, Sony A7R V, Hasselblad X2D with 85mm f/1.4 or 135mm f/2\n"
        "- Landscapes: Phase One XT, Fujifilm GFX 100S with 16-35mm f/2.8 or 40mm f/4\n"
        "- Documentary: Canon EOS-1D X, Sony A9 III with 24-105mm f/4 or 70-200mm f/2.8\n"
        "- Action/Sports: Canon EOS R3, Sony A1 with 300mm f/2.8 or 400mm f/2.8\n"
        "Match the equipment to what you actually observe in the scene type and shooting conditions."
    )

    # Looser headings tried, in order, when there is no usable "CAMERA_SETUP:" section.
    CAMERA_PATTERNS = (
        "Shot on ",
        "Camera: ",
        "Equipment: ",
        "Recommended camera:",
        "Camera setup:",
    )

    def __init__(self):
        super().__init__()
        self.client = None  # gradio_client.Client, created lazily by initialize()
        self.space_url = "Malaji71/Bagel-7B-Demo"
        self.api_endpoint = "/image_understanding"

    def initialize(self) -> bool:
        """Create the ``gradio_client`` client for the BAGEL Space.

        Returns:
            True if a client is ready. False if creating it raised; the error is logged.
        """
        # Also check the client: after cleanup() the analyzer must reconnect
        # rather than report itself ready with no client.
        if self.is_initialized and self.client is not None:
            return True
        try:
            logger.info("Initializing BAGEL API client...")
            self.client = Client(self.space_url)
            self.is_initialized = True
            logger.info("BAGEL API client initialized successfully")
            return True
        except Exception as e:
            logger.error(f"BAGEL API client initialization failed: {e}")
            return False

    def _extract_camera_setup(self, description: str) -> Optional[str]:
        """Pull the camera/lens recommendation out of a BAGEL response.

        Prefers an explicit ``CAMERA_SETUP:`` section, which is the format the
        class prompts ask for. If that section is missing or empty, it looks for
        one of ``CAMERA_PATTERNS`` and returns the sentence starting there.

        Args:
            description: Text returned by the API.

        Returns:
            The recommendation with runs of whitespace collapsed to single
            spaces, or None if nothing usable was found.
        """
        import re  # only needed by this parser

        if not description:
            return None
        try:
            if "CAMERA_SETUP:" in description:
                camera_part = description.split("CAMERA_SETUP:")[1]
                # Models often bold the heading ("**CAMERA_SETUP:**"), which
                # leaves emphasis markers right after the split point.
                camera_part = " ".join(camera_part.strip().lstrip("*_: ").split())
                if camera_part:
                    return camera_part

            for pattern in self.CAMERA_PATTERNS:
                idx = description.find(pattern)
                if idx == -1:
                    continue
                tail = description[idx:]
                # End the sentence only at a period followed by whitespace or end
                # of text, so apertures like "f/1.4" are not cut in half.
                sentence_end = re.search(r"\.(?=\s|$)", tail)
                camera_text = tail[:sentence_end.start()] if sentence_end else tail
                if len(camera_text) > len(pattern) + 10:  # require meaningful content
                    return camera_text.strip()
            return None
        except Exception as e:
            logger.warning(f"Failed to extract camera setup: {e}")
            return None

    def _save_temp_image(self, image: Image.Image) -> Optional[str]:
        """Write *image* as an RGB PNG to a new temporary file.

        The caller owns the file and should remove it with ``_cleanup_temp_file``.

        Returns:
            Path of the written file, or None on failure. On failure the error
            is logged and any partially created file is removed.
        """
        temp_path = None
        try:
            # delete=False: the file must outlive this handle so it can be uploaded.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                temp_path = temp_file.name
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image.save(temp_path, 'PNG')
            return temp_path
        except Exception as e:
            logger.error(f"Failed to save temporary image: {e}")
            # Don't leak the (empty) temp file if saving failed after creation.
            self._cleanup_temp_file(temp_path)
            return None

    def _cleanup_temp_file(self, file_path: Optional[str]) -> None:
        """Delete *file_path* if it exists. Failures are logged, not raised."""
        try:
            if file_path and os.path.exists(file_path):
                os.unlink(file_path)
        except Exception as e:
            logger.warning(f"Failed to cleanup temp file: {e}")

    @staticmethod
    def _extract_text_response(result: Any) -> str:
        """Normalize a ``Client.predict`` return value to stripped text ("" if none).

        A tuple is assumed to be ``(image_result, text_response)``, and the text
        element is preferred. Any other value is converted with ``str()``.
        """
        if isinstance(result, tuple) and len(result) >= 2:
            text = result[1] if result[1] else result[0]
        else:
            text = str(result)
        return text.strip() if isinstance(text, str) else ""

    def analyze_image(self, image: Image.Image, prompt: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Analyze *image* with the BAGEL Space.

        Args:
            image: PIL image to analyze. It is converted to RGB before upload.
            prompt: Instruction sent to the model. Defaults to ``DEFAULT_ANALYSIS_PROMPT``.

        Returns:
            ``(description, metadata)``. On success, metadata contains:
            - the model and endpoint info;
            - ``prompt_used``, the caller's prompt, or None when the default was used;
            - ``has_camera_suggestion``, plus ``camera_setup`` when one was found;
            - ``response_length``.
            On failure, metadata contains an ``"error"`` key. ``ModelManager``
            uses that key to decide whether to fall back.
        """
        if not self.initialize():
            return "BAGEL API not available", {"error": "API initialization failed"}

        temp_path = None
        metadata = {
            "model": "BAGEL-7B-API",
            "device": "api",
            "confidence": 0.9,  # fixed value; the API does not report a confidence
            "api_endpoint": self.api_endpoint,
            "space_url": self.space_url,
            "prompt_used": prompt,
            "has_camera_suggestion": False,
        }
        try:
            if prompt is None:
                prompt = self.DEFAULT_ANALYSIS_PROMPT

            temp_path = self._save_temp_image(image)
            if not temp_path:
                return "Image processing failed", {"error": "Could not save image"}

            logger.info("Calling BAGEL API for image analysis...")
            result = self.client.predict(
                image=handle_file(temp_path),
                prompt=prompt,
                show_thinking=False,
                do_sample=False,
                text_temperature=0.3,
                max_new_tokens=512,
                api_name=self.api_endpoint,
            )

            description = self._extract_text_response(result)
            if description:
                camera_setup = self._extract_camera_setup(description)
                if camera_setup:
                    metadata["camera_setup"] = camera_setup
                    metadata["has_camera_suggestion"] = True
            else:
                description = "Detailed image analysis completed successfully"

            metadata["response_length"] = len(description)
            logger.info(f"BAGEL API analysis complete: {len(description)} characters")
            return description, metadata
        except Exception as e:
            logger.exception(f"BAGEL API analysis failed: {e}")
            return "API analysis failed", {"error": str(e), "model": "BAGEL-7B-API"}
        finally:
            # The upload copy is never needed after the call, success or not.
            if temp_path:
                self._cleanup_temp_file(temp_path)

    def analyze_for_flux_prompt(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze *image* using ``FLUX_ANALYSIS_PROMPT``. The return value is the same as ``analyze_image``."""
        return self.analyze_image(image, self.FLUX_ANALYSIS_PROMPT)

    def cleanup(self) -> None:
        """Drop the API client and free memory. Errors are logged, not raised.

        The analyzer stays usable: the next ``analyze_image`` call reconnects.
        """
        try:
            self.client = None
            self.is_initialized = False
            super().cleanup()
            logger.info("BAGEL API resources cleaned up")
        except Exception as e:
            logger.warning(f"BAGEL API cleanup warning: {e}")
class FallbackAnalyzer(BaseImageAnalyzer):
    """Dimension-only analyzer used when the BAGEL API is not available.

    It never inspects pixel content. It derives an orientation and a generic
    lens suggestion from ``image.size``, and returns a templated description.
    """

    # Width/height ratios beyond which a dedicated lens suggestion is given.
    WIDE_RATIO = 1.5
    TALL_RATIO = 0.75
    # Ratios within this distance of 1.0 are labelled "square".
    SQUARE_TOLERANCE = 0.05

    def initialize(self) -> bool:
        """Mark the analyzer ready. It has nothing to load, so this always returns True."""
        self.is_initialized = True
        return True

    def analyze_image(self, image: Image.Image, prompt: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Return a generic description based on the image's dimensions.

        Args:
            image: PIL image. Only ``size`` and ``mode`` are read.
            prompt: Accepted for signature parity with ``BagelAPIAnalyzer``; ignored.

        Returns:
            ``(description, metadata)``. If anything raises (e.g. a zero height),
            a generic description is returned with ``"error"`` in the metadata.
        """
        try:
            width, height = image.size
            mode = image.mode
            aspect_ratio = width / height

            # Orientation is reported from the actual ratio, so 4:3 and 3:2
            # images are no longer mislabelled "square".
            if abs(aspect_ratio - 1.0) <= self.SQUARE_TOLERANCE:
                orientation = "square"
            elif aspect_ratio > 1.0:
                orientation = "landscape"
            else:
                orientation = "portrait"

            # Lens advice only changes for strongly wide or tall frames.
            if aspect_ratio > self.WIDE_RATIO:
                camera_suggestion = "wide-angle lens, landscape photography"
            elif aspect_ratio < self.TALL_RATIO:
                camera_suggestion = "portrait lens, shallow depth of field"
            else:
                camera_suggestion = "standard lens, balanced composition"

            description = f"A {orientation} format image with professional composition. The image shows clear detail and good visual balance, suitable for high-quality reproduction. Recommended camera setup: {camera_suggestion}, professional lighting with careful attention to exposure and color balance."
            metadata = {
                "model": "Fallback",
                "device": "cpu",
                "confidence": 0.6,  # fixed value, not computed
                "image_size": f"{width}x{height}",
                "color_mode": mode,
                "orientation": orientation,
                "aspect_ratio": round(aspect_ratio, 2),
            }
            return description, metadata
        except Exception as e:
            logger.error(f"Fallback analysis failed: {e}")
            return "Professional image suitable for detailed analysis and prompt generation", {"error": str(e), "model": "Fallback"}
class ModelManager:
    """Create, cache and dispatch to image analyzers, falling back on failure.

    Analyzers are created on first use and cached by model name. If the
    requested analyzer raises or returns metadata with an ``"error"`` key,
    the call is retried with ``FallbackAnalyzer``.
    """

    def __init__(self, preferred_model: str = "bagel-api"):
        self.preferred_model = preferred_model
        self.analyzers = {}  # model name -> analyzer instance
        self.current_analyzer = None

    def get_analyzer(self, model_name: str = None) -> Optional[BaseImageAnalyzer]:
        """Return the cached analyzer for *model_name*, creating it if needed.

        Args:
            model_name: ``"bagel-api"`` or ``"fallback"``; defaults to the preferred
                model. Unknown names log a warning and resolve to the fallback analyzer.
        """
        model_name = model_name or self.preferred_model
        if model_name not in self.analyzers:
            if model_name == "bagel-api":
                self.analyzers[model_name] = BagelAPIAnalyzer()
            elif model_name == "fallback":
                self.analyzers[model_name] = FallbackAnalyzer()
            else:
                logger.warning(f"Unknown model: {model_name}, using fallback")
                model_name = "fallback"
                if model_name not in self.analyzers:
                    self.analyzers[model_name] = FallbackAnalyzer()
        return self.analyzers[model_name]

    def analyze_image(self, image: Image.Image, model_name: str = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
        """Analyze *image* with the requested model, falling back if it fails.

        Args:
            image: PIL image to analyze.
            model_name: Model to use; defaults to the preferred model.
            analysis_type: ``"flux"`` uses ``analyze_for_flux_prompt`` when the
                analyzer has it. Any other value uses ``analyze_image``.

        Returns:
            ``(description, metadata)``. A fallback result additionally carries
            ``fallback_from`` and ``primary_error`` in its metadata.
        """
        analyzer = self.get_analyzer(model_name)
        if analyzer is None:
            return "No analyzer available", {"error": "Model not found"}

        if analysis_type == "flux" and hasattr(analyzer, 'analyze_for_flux_prompt'):
            success, result = safe_execute(analyzer.analyze_for_flux_prompt, image)
        else:
            success, result = safe_execute(analyzer.analyze_image, image)

        if success and result[1].get("error") is None:
            return result

        fallback_analyzer = self.get_analyzer("fallback")
        if analyzer is fallback_analyzer:
            # The failing analyzer already is the fallback. Re-running it would
            # only repeat the same failure.
            if success:
                return result
            return "All analyzers failed", {"error": "Complete analysis failure"}

        logger.warning(f"Primary model failed, using fallback: {result}")
        primary_error = result[1].get("error") if success else result
        fallback_success, fallback_result = safe_execute(fallback_analyzer.analyze_image, image)
        if not fallback_success:
            return "All analyzers failed", {"error": "Complete analysis failure"}

        if (isinstance(fallback_result, tuple) and len(fallback_result) == 2
                and isinstance(fallback_result[1], dict)):
            description, metadata = fallback_result
            # Surface why we fell back; otherwise it is only visible in the logs.
            metadata = {
                **metadata,
                "fallback_from": model_name or self.preferred_model,
                "primary_error": str(primary_error),
            }
            return description, metadata
        return fallback_result

    def cleanup_all(self) -> None:
        """Clean up every cached analyzer, empty the cache and free memory.

        A failure in one analyzer's ``cleanup`` is logged and does not stop the others.
        """
        for name, analyzer in self.analyzers.items():
            try:
                analyzer.cleanup()
            except Exception as e:
                logger.warning(f"Cleanup failed for analyzer {name}: {e}")
        self.analyzers.clear()
        clean_memory()
        logger.info("All analyzers cleaned up")
# Shared module-wide manager; the module-level analyze_image() delegates to it.
model_manager = ModelManager("bagel-api")
def analyze_image(image: Image.Image, model_name: str = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
    """Analyze *image* through the shared ``model_manager``.

    Args:
        image: PIL image to analyze.
        model_name: ``"bagel-api"`` or ``"fallback"``. None uses the manager's preferred model.
        analysis_type: ``"detailed"`` (default) or ``"flux"``.

    Returns:
        ``(description, metadata)`` as produced by ``ModelManager.analyze_image``.
    """
    manager = model_manager
    return manager.analyze_image(image, model_name=model_name, analysis_type=analysis_type)
# Names exported by a star-import of this module.
__all__ = [
    # Analyzer classes
    "BaseImageAnalyzer",
    "BagelAPIAnalyzer",
    "FallbackAnalyzer",
    # Orchestration and convenience entry points
    "ModelManager",
    "model_manager",
    "analyze_image",
]