"""
CLIP-based Quality Validation System
Validates generated images for prompt alignment and quality
"""
import torch
from PIL import Image
from typing import Dict, Optional
import warnings
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")
class QualityValidator:
    """
    Uses CLIP to validate prompt-image alignment and quality.

    The CLIP model is loaded lazily on first use. If it cannot be loaded
    (e.g. transformers missing, no network), validation degrades gracefully
    to a neutral score of 0.5 instead of raising.
    """

    def __init__(self, device: Optional[str] = None):
        """
        Initialize the quality validator.

        Args:
            device: Device to run validation on ('cuda' or 'cpu').
                Defaults to 'cuda' when available, otherwise 'cpu'.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None       # CLIPModel, set by _lazy_load()
        self.processor = None   # CLIPProcessor, set by _lazy_load()
        self._initialized = False

    def _lazy_load(self):
        """
        Lazily load the CLIP model and processor to save memory.

        On failure, prints a warning and leaves validation disabled
        (``_initialized`` stays False) rather than raising.
        """
        if self._initialized:
            return
        try:
            from transformers import CLIPProcessor, CLIPModel
            print("Loading CLIP model for quality validation...")
            self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self._initialized = True
            print("CLIP model loaded successfully!")
        except Exception as e:
            print(f"Warning: Could not load CLIP model: {e}")
            print("Quality validation will be disabled.")
            self._initialized = False

    def validate(
        self,
        image: "Image.Image",
        prompt: str,
        return_details: bool = False
    ) -> float:
        """
        Validate image quality and prompt alignment.

        Args:
            image: Generated PIL Image.
            prompt: Original prompt used for generation.
            return_details: Currently unused; kept for backward
                compatibility with existing callers.

        Returns:
            CLIP image-text cosine similarity clamped to [0, 1] (higher is
            better; well-aligned pairs typically land around 0.25-0.40).
            Returns a neutral 0.5 when CLIP is unavailable or errors occur.
        """
        self._lazy_load()
        if not self._initialized:
            # Return neutral score if CLIP is not available
            return 0.5
        try:
            # Prepare inputs
            inputs = self.processor(
                text=[prompt],
                images=image,
                return_tensors="pt",
                padding=True
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
                # BUG FIX: the previous code applied softmax(dim=1) over a
                # single text candidate, which always yields exactly 1.0 —
                # every image scored "perfect" and the 0.20-0.35 thresholds
                # in get_quality_feedback()/should_retry() were unreachable.
                # logits_per_image == cosine_similarity * logit_scale.exp(),
                # so divide by the learned scale to recover the raw cosine
                # similarity, which those thresholds actually target.
                logit_scale = self.model.logit_scale.exp().item()
                similarity = outputs.logits_per_image[0][0].item() / logit_scale
            # Clamp into the documented [0, 1] range (cosine can be negative)
            return max(0.0, min(1.0, similarity))
        except Exception as e:
            print(f"Validation error: {e}")
            return 0.5  # Return neutral score on error

    def get_quality_feedback(self, score: float) -> str:
        """
        Get human-readable feedback based on quality score.

        Args:
            score: Quality score from validation.

        Returns:
            Feedback string.
        """
        # Thresholds reflect typical CLIP cosine-similarity ranges
        if score >= 0.35:
            return "✅ Excellent - Strong prompt alignment"
        elif score >= 0.28:
            return "✓ Good - Acceptable quality"
        elif score >= 0.20:
            return "⚠ Fair - Consider refining prompt"
        else:
            return "❌ Poor - Prompt may need significant changes"

    def should_retry(self, score: float, min_score: float = 0.25) -> bool:
        """
        Determine if image should be regenerated based on quality.

        Args:
            score: Quality score.
            min_score: Minimum acceptable score.

        Returns:
            True if should retry generation.
        """
        return score < min_score

    def batch_validate(
        self,
        images: list,
        prompts: list
    ) -> list:
        """
        Validate multiple images at once.

        Args:
            images: List of PIL Images.
            prompts: List of corresponding prompts (paired positionally;
                extras in the longer list are ignored by zip).

        Returns:
            List of quality scores, one per image/prompt pair.
        """
        return [self.validate(image, prompt) for image, prompt in zip(images, prompts)]
# Manual smoke test: run this module directly to exercise the validator
if __name__ == "__main__":
    print("=== Quality Validator Test ===\n")

    validator = QualityValidator()

    # A solid-color canvas is the simplest possible validation input
    test_prompt = "a red square"
    test_image = Image.new('RGB', (512, 512), color='red')

    print(f"Testing with prompt: '{test_prompt}'")
    score = validator.validate(test_image, test_prompt)

    print(f"Quality Score: {score:.4f}")
    print(f"Feedback: {validator.get_quality_feedback(score)}")
    print(f"Should Retry: {validator.should_retry(score)}")