# ai-image-generator / quality_validator.py
"""
CLIP-based Quality Validation System
Validates generated images for prompt alignment and quality
"""
import torch
from PIL import Image
from typing import Dict, Optional
import warnings
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")
class QualityValidator:
"""
Uses CLIP to validate prompt-image alignment and quality
"""
def __init__(self, device: Optional[str] = None):
"""
Initialize the quality validator
Args:
device: Device to run validation on ('cuda' or 'cpu')
"""
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.model = None
self.processor = None
self._initialized = False
def _lazy_load(self):
"""
Lazy load CLIP model to save memory
"""
if self._initialized:
return
try:
from transformers import CLIPProcessor, CLIPModel
print("Loading CLIP model for quality validation...")
self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
self._initialized = True
print("CLIP model loaded successfully!")
except Exception as e:
print(f"Warning: Could not load CLIP model: {e}")
print("Quality validation will be disabled.")
self._initialized = False
def validate(
self,
image: Image.Image,
prompt: str,
return_details: bool = False
) -> float:
"""
Validate image quality and prompt alignment
Args:
image: Generated PIL Image
prompt: Original prompt used for generation
return_details: Whether to return detailed validation info
Returns:
Quality score between 0 and 1 (higher is better)
"""
self._lazy_load()
if not self._initialized:
# Return neutral score if CLIP is not available
return 0.5
try:
# Prepare inputs
inputs = self.processor(
text=[prompt],
images=image,
return_tensors="pt",
padding=True
).to(self.device)
# Get CLIP similarity score
with torch.no_grad():
outputs = self.model(**inputs)
logits_per_image = outputs.logits_per_image
# Convert to probability (0-1 range)
probs = logits_per_image.softmax(dim=1)
score = probs[0][0].item()
return score
except Exception as e:
print(f"Validation error: {e}")
return 0.5 # Return neutral score on error
def get_quality_feedback(self, score: float) -> str:
"""
Get human-readable feedback based on quality score
Args:
score: Quality score from validation
Returns:
Feedback string
"""
if score >= 0.35:
return "✅ Excellent - Strong prompt alignment"
elif score >= 0.28:
return "✓ Good - Acceptable quality"
elif score >= 0.20:
return "⚠ Fair - Consider refining prompt"
else:
return "❌ Poor - Prompt may need significant changes"
def should_retry(self, score: float, min_score: float = 0.25) -> bool:
"""
Determine if image should be regenerated based on quality
Args:
score: Quality score
min_score: Minimum acceptable score
Returns:
True if should retry generation
"""
return score < min_score
def batch_validate(
self,
images: list,
prompts: list
) -> list:
"""
Validate multiple images at once
Args:
images: List of PIL Images
prompts: List of corresponding prompts
Returns:
List of quality scores
"""
scores = []
for image, prompt in zip(images, prompts):
score = self.validate(image, prompt)
scores.append(score)
return scores
# Manual smoke test: run this module directly to exercise the validator
# end-to-end with a trivial solid-color image.
if __name__ == "__main__":
    print("=== Quality Validator Test ===\n")
    validator = QualityValidator()

    # A plain red canvas paired with a matching prompt.
    sample_image = Image.new('RGB', (512, 512), color='red')
    sample_prompt = "a red square"

    print(f"Testing with prompt: '{sample_prompt}'")
    result = validator.validate(sample_image, sample_prompt)
    print(f"Quality Score: {result:.4f}")
    print(f"Feedback: {validator.get_quality_feedback(result)}")
    print(f"Should Retry: {validator.should_retry(result)}")