File size: 5,182 Bytes
cce2b06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""

CLIP-based Quality Validation System

Validates generated images for prompt alignment and quality

"""

import torch
from PIL import Image
from typing import Dict, Optional
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")


class QualityValidator:
    """CLIP-based checker for prompt-image alignment.

    Scores how well a generated image matches its text prompt using CLIP's
    joint image/text embedding space. The model is loaded lazily on first
    use so constructing a validator is cheap.
    """

    def __init__(self, device: Optional[str] = None):
        """Initialize the quality validator.

        Args:
            device: Device to run validation on ('cuda' or 'cpu').
                Defaults to 'cuda' when available, otherwise 'cpu'.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None       # CLIPModel, populated by _lazy_load()
        self.processor = None   # CLIPProcessor, populated by _lazy_load()
        self._initialized = False

    def _lazy_load(self):
        """Load the CLIP model on first use to avoid paying its memory cost up front."""
        if self._initialized:
            return

        try:
            from transformers import CLIPProcessor, CLIPModel

            print("Loading CLIP model for quality validation...")
            self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self._initialized = True
            print("CLIP model loaded successfully!")

        except Exception as e:
            # Validation is best-effort: degrade gracefully rather than crash
            # the generation pipeline when CLIP cannot be loaded.
            print(f"Warning: Could not load CLIP model: {e}")
            print("Quality validation will be disabled.")
            self._initialized = False

    def validate(
        self,
        image: "Image.Image",
        prompt: str,
        return_details: bool = False,
    ):
        """Validate image quality and prompt alignment.

        Args:
            image: Generated PIL Image.
            prompt: Original prompt used for generation.
            return_details: When True, return a ``{"score", "feedback"}``
                dict instead of a bare float (previously this flag was
                accepted but ignored).

        Returns:
            Quality score between 0 and 1 (higher is better), or a details
            dict when ``return_details`` is True. A neutral 0.5 is returned
            when CLIP is unavailable or validation fails.
        """
        self._lazy_load()

        if not self._initialized:
            # Neutral score when CLIP is not available.
            return self._package(0.5, return_details)

        try:
            # Prepare joint text/image inputs for CLIP.
            inputs = self.processor(
                text=[prompt],
                images=image,
                return_tensors="pt",
                padding=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                # BUG FIX: with a single prompt, softmax over dim=1 of
                # logits_per_image (shape [1, 1]) is ALWAYS exactly 1.0,
                # so the old probs-based score was meaningless. Recover the
                # raw image-text cosine similarity by undoing CLIP's learned
                # logit scale; typical values fall in ~0.15-0.40, which
                # matches the thresholds in get_quality_feedback().
                logits_per_image = outputs.logits_per_image
                cosine = (logits_per_image / self.model.logit_scale.exp()).item()

            # Clamp to [0, 1]: cosine similarity can be slightly negative.
            score = max(0.0, min(1.0, cosine))
            return self._package(score, return_details)

        except Exception as e:
            print(f"Validation error: {e}")
            return self._package(0.5, return_details)  # neutral score on error

    def _package(self, score: float, return_details: bool):
        """Return either the bare score or a details dict, per the caller's request."""
        if return_details:
            return {"score": score, "feedback": self.get_quality_feedback(score)}
        return score

    def get_quality_feedback(self, score: float) -> str:
        """Get human-readable feedback based on quality score.

        Args:
            score: Quality score from validation.

        Returns:
            Feedback string.
        """
        # Thresholds are calibrated for raw CLIP cosine similarities.
        if score >= 0.35:
            return "✅ Excellent - Strong prompt alignment"
        elif score >= 0.28:
            return "✓ Good - Acceptable quality"
        elif score >= 0.20:
            return "⚠ Fair - Consider refining prompt"
        else:
            return "❌ Poor - Prompt may need significant changes"

    def should_retry(self, score: float, min_score: float = 0.25) -> bool:
        """Determine if the image should be regenerated based on quality.

        Args:
            score: Quality score.
            min_score: Minimum acceptable score.

        Returns:
            True if generation should be retried.
        """
        return score < min_score

    def batch_validate(
        self,
        images: list,
        prompts: list
    ) -> list:
        """Validate multiple images at once.

        Args:
            images: List of PIL Images.
            prompts: List of corresponding prompts.

        Returns:
            List of quality scores, one per (image, prompt) pair; extras on
            either side are ignored (zip semantics).
        """
        return [self.validate(image, prompt) for image, prompt in zip(images, prompts)]


# Smoke test: run a quick end-to-end validation on a synthetic image.
if __name__ == "__main__":
    print("=== Quality Validator Test ===\n")

    validator = QualityValidator()

    # A flat red canvas paired with a matching description.
    sample_image = Image.new('RGB', (512, 512), color='red')
    sample_prompt = "a red square"

    print(f"Testing with prompt: '{sample_prompt}'")
    result = validator.validate(sample_image, sample_prompt)
    print(f"Quality Score: {result:.4f}")
    print(f"Feedback: {validator.get_quality_feedback(result)}")
    print(f"Should Retry: {validator.should_retry(result)}")