""" CLIP-based Quality Validation System Validates generated images for prompt alignment and quality """ import torch from PIL import Image from typing import Dict, Optional import warnings # Suppress warnings for cleaner output warnings.filterwarnings("ignore") class QualityValidator: """ Uses CLIP to validate prompt-image alignment and quality """ def __init__(self, device: Optional[str] = None): """ Initialize the quality validator Args: device: Device to run validation on ('cuda' or 'cpu') """ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.model = None self.processor = None self._initialized = False def _lazy_load(self): """ Lazy load CLIP model to save memory """ if self._initialized: return try: from transformers import CLIPProcessor, CLIPModel print("Loading CLIP model for quality validation...") self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device) self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") self._initialized = True print("CLIP model loaded successfully!") except Exception as e: print(f"Warning: Could not load CLIP model: {e}") print("Quality validation will be disabled.") self._initialized = False def validate( self, image: Image.Image, prompt: str, return_details: bool = False ) -> float: """ Validate image quality and prompt alignment Args: image: Generated PIL Image prompt: Original prompt used for generation return_details: Whether to return detailed validation info Returns: Quality score between 0 and 1 (higher is better) """ self._lazy_load() if not self._initialized: # Return neutral score if CLIP is not available return 0.5 try: # Prepare inputs inputs = self.processor( text=[prompt], images=image, return_tensors="pt", padding=True ).to(self.device) # Get CLIP similarity score with torch.no_grad(): outputs = self.model(**inputs) logits_per_image = outputs.logits_per_image # Convert to probability (0-1 range) probs = logits_per_image.softmax(dim=1) score = probs[0][0].item() return score except Exception as e: print(f"Validation error: {e}") return 0.5 # Return neutral score on error def get_quality_feedback(self, score: float) -> str: """ Get human-readable feedback based on quality score Args: score: Quality score from validation Returns: Feedback string """ if score >= 0.35: return "✅ Excellent - Strong prompt alignment" elif score >= 0.28: return "✓ Good - Acceptable quality" elif score >= 0.20: return "⚠ Fair - Consider refining prompt" else: return "❌ Poor - Prompt may need significant changes" def should_retry(self, score: float, min_score: float = 0.25) -> bool: """ Determine if image should be regenerated based on quality Args: score: Quality score min_score: Minimum acceptable score Returns: True if should retry generation """ return score < min_score def batch_validate( self, images: list, prompts: list ) -> list: """ Validate multiple images at once Args: images: List of PIL Images prompts: List of corresponding prompts Returns: List of quality scores """ scores = [] for image, prompt in zip(images, prompts): score = self.validate(image, prompt) scores.append(score) return scores # Test function if __name__ == "__main__": print("=== Quality Validator Test ===\n") validator = QualityValidator() # Create a simple test image test_image = Image.new('RGB', (512, 512), color='red') test_prompt = "a red square" print(f"Testing with prompt: '{test_prompt}'") score = validator.validate(test_image, test_prompt) print(f"Quality Score: {score:.4f}") print(f"Feedback: 
{validator.get_quality_feedback(score)}") print(f"Should Retry: {validator.should_retry(score)}")
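
    # A minimal sketch exercising batch_validate(), added for illustration.
    # The colour swatches and prompts below are made-up examples, not
    # fixtures from the original test suite.
    batch_images = [
        Image.new('RGB', (512, 512), color=c) for c in ('blue', 'green')
    ]
    batch_prompts = ["a blue square", "a green square"]
    batch_scores = validator.batch_validate(batch_images, batch_prompts)
    for batch_prompt, batch_score in zip(batch_prompts, batch_scores):
        print(
            f"Batch '{batch_prompt}': {batch_score:.4f} "
            f"-> {validator.get_quality_feedback(batch_score)}"
        )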