# NOTE(review): removed non-code page-scrape residue ("Spaces: Sleeping" status lines).
"""
CLIP-based Quality Validation System
Validates generated images for prompt alignment and quality
"""
import torch
from PIL import Image
from typing import Dict, Optional
import warnings

# Suppress warnings for cleaner output
# NOTE(review): this silences ALL warnings process-wide (including torch /
# transformers deprecation notices, which are the usual noise here) —
# confirm the blanket filter is intentional.
warnings.filterwarnings("ignore")
class QualityValidator:
    """
    Uses CLIP to validate prompt-image alignment and quality.

    The CLIP model is loaded lazily on first use to save memory. If it
    cannot be loaded (e.g. ``transformers`` missing or no network for the
    weights download), validation degrades gracefully: every call returns
    a neutral score of 0.5 instead of raising.
    """

    def __init__(self, device: Optional[str] = None):
        """
        Initialize the quality validator.

        Args:
            device: Device to run validation on ('cuda' or 'cpu').
                Defaults to 'cuda' when available, otherwise 'cpu'.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None       # CLIPModel, populated by _lazy_load()
        self.processor = None   # CLIPProcessor, populated by _lazy_load()
        self._initialized = False

    def _lazy_load(self):
        """
        Lazy load CLIP model to save memory.

        Any failure (missing transformers, download error) is caught and
        reported; validation is then disabled rather than crashing callers.
        """
        if self._initialized:
            return
        try:
            from transformers import CLIPProcessor, CLIPModel

            print("Loading CLIP model for quality validation...")
            self.model = CLIPModel.from_pretrained(
                "openai/clip-vit-base-patch32"
            ).to(self.device)
            self.processor = CLIPProcessor.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
            self._initialized = True
            print("CLIP model loaded successfully!")
        except Exception as e:
            print(f"Warning: Could not load CLIP model: {e}")
            print("Quality validation will be disabled.")
            self._initialized = False

    def validate(
        self,
        image: Image.Image,
        prompt: str,
        return_details: bool = False
    ) -> float:
        """
        Validate image quality and prompt alignment.

        Args:
            image: Generated PIL Image
            prompt: Original prompt used for generation
            return_details: When True, return a dict with the score plus
                human-readable feedback instead of the bare float.

        Returns:
            Quality score between 0 and 1 (higher is better) — the CLIP
            image/text cosine similarity clamped to [0, 1]. Well-aligned
            pairs typically score around 0.25-0.40, which is the range the
            thresholds in get_quality_feedback()/should_retry() assume.
            When ``return_details`` is True, a dict with keys
            'score', 'feedback' and 'should_retry' is returned instead.
        """
        self._lazy_load()
        if not self._initialized:
            # Return neutral score if CLIP is not available
            return self._package(0.5, return_details)
        try:
            # Prepare inputs
            inputs = self.processor(
                text=[prompt],
                images=image,
                return_tensors="pt",
                padding=True
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
                # BUG FIX: the previous code applied softmax to
                # logits_per_image over a single text candidate, which is
                # always exactly 1.0 regardless of alignment. Use the raw
                # image/text cosine similarity instead; that matches the
                # 0.20/0.28/0.35 thresholds used elsewhere in this class.
                image_embeds = outputs.image_embeds
                text_embeds = outputs.text_embeds
                # Normalize defensively (HF CLIPModel already returns
                # L2-normalized projection embeds; renormalizing is a no-op).
                image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
                text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
                similarity = (image_embeds @ text_embeds.T)[0, 0].item()
            # Cosine similarity lies in [-1, 1]; clamp to the documented [0, 1].
            score = min(1.0, max(0.0, similarity))
            return self._package(score, return_details)
        except Exception as e:
            print(f"Validation error: {e}")
            return self._package(0.5, return_details)  # neutral score on error

    def _package(self, score: float, return_details: bool):
        """Return the bare score, or a details dict when requested (internal)."""
        if not return_details:
            return score
        return {
            "score": score,
            "feedback": self.get_quality_feedback(score),
            "should_retry": self.should_retry(score),
        }

    def get_quality_feedback(self, score: float) -> str:
        """
        Get human-readable feedback based on quality score.

        Args:
            score: Quality score from validation

        Returns:
            Feedback string
        """
        # Thresholds are calibrated for CLIP cosine-similarity scores,
        # which rarely exceed ~0.4 even for excellent alignment.
        if score >= 0.35:
            return "✅ Excellent - Strong prompt alignment"
        elif score >= 0.28:
            return "✓ Good - Acceptable quality"
        elif score >= 0.20:
            return "⚠ Fair - Consider refining prompt"
        else:
            return "❌ Poor - Prompt may need significant changes"

    def should_retry(self, score: float, min_score: float = 0.25) -> bool:
        """
        Determine if image should be regenerated based on quality.

        Args:
            score: Quality score
            min_score: Minimum acceptable score

        Returns:
            True if should retry generation
        """
        return score < min_score

    def batch_validate(
        self,
        images: list,
        prompts: list
    ) -> list:
        """
        Validate multiple images at once.

        Args:
            images: List of PIL Images
            prompts: List of corresponding prompts

        Returns:
            List of quality scores, one per (image, prompt) pair.
            NOTE: zip() silently truncates to the shorter list if the
            inputs have mismatched lengths.
        """
        return [self.validate(image, prompt) for image, prompt in zip(images, prompts)]
# Manual smoke test: run this module directly to sanity-check the validator.
if __name__ == "__main__":
    print("=== Quality Validator Test ===\n")

    validator = QualityValidator()

    # A trivial solid-color image is enough to exercise the scoring path.
    test_prompt = "a red square"
    test_image = Image.new('RGB', (512, 512), color='red')

    print(f"Testing with prompt: '{test_prompt}'")
    quality = validator.validate(test_image, test_prompt)

    print(f"Quality Score: {quality:.4f}")
    print(f"Feedback: {validator.get_quality_feedback(quality)}")
    print(f"Should Retry: {validator.should_retry(quality)}")