"""
Image Evaluator Core Logic

Contains the main evaluation classes:
- ImageEvaluator: For text-to-image generation quality assessment
- EditEvaluator: For image editing quality assessment
"""

import logging
import re
import math
import time
from typing import Optional, List, Dict, Any, Callable, Sequence
from dataclasses import dataclass, field

from PIL import Image

from metrics import (
    parse_json_robust,
    calculate_sharpness,
    calculate_colorfulness,
    calculate_contrast,
    calculate_ssim,
    calculate_psnr,
    calculate_clip_score,
    calculate_lpips,
    score_to_grade,
    geometric_mean,
)

logger = logging.getLogger(__name__)

# Score in [0, 1] assigned to each artifact-severity label reported by the VLM.
_SEVERITY_SCORES = {"none": 1.0, "minor": 0.85, "moderate": 0.6, "major": 0.3, "unknown": 0.7}
_DEFAULT_SEVERITY_SCORE = 0.7

# Multiplicative penalty applied to the edit-quality score per severity label.
_EDIT_SEVERITY_PENALTIES = {"major": 0.7, "moderate": 0.85, "minor": 0.95, "none": 1.0}
_DEFAULT_EDIT_SEVERITY_PENALTY = 0.9

# Patterns for the free-text answer formats requested from the VLM.
_CONFIDENCE_RE = re.compile(r'confidence[:\s]*(\d+)%?')
_SCORE_OUT_OF_TEN_RE = re.compile(r'[Ss]core[:\s]*(\d+(?:\.\d+)?)\s*/\s*10')

# Upper bounds on how many decomposed items are verified with one VLM call each.
_MAX_PRIMITIVES = 20
_MAX_EDITS = 10


def _clamp(value: float, low: float, high: float) -> float:
    """Restrict ``value`` to the closed interval [low, high]."""
    return max(low, min(high, value))


def _to_float(value: Any) -> Optional[float]:
    """Convert ``value`` to a finite float, or return None if that is impossible.

    Booleans are rejected: a JSON ``true`` is not a numeric rating.
    """
    if value is None or isinstance(value, bool):
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if math.isfinite(number) else None


def _rating_from(data: Dict[str, Any], key: str, default: float = 5.0) -> float:
    """Read a 0-10 rating stored as ``data[key]`` or ``data[key]["score"]``.

    A missing or unparseable rating yields ``default``. A genuine rating of 0
    is kept (it is not mistaken for "missing"), and out-of-range values are
    clamped to [0, 10].
    """
    raw = data.get(key)
    if isinstance(raw, dict):
        raw = raw.get("score")
    number = _to_float(raw)
    return default if number is None else _clamp(number, 0.0, 10.0)


def _mean_or_none(values: Sequence[float]) -> Optional[float]:
    """Arithmetic mean of ``values``, or None when there are none."""
    return sum(values) / len(values) if values else None


def _severity_score(severity: Any) -> float:
    """Map an artifact-severity label to a 0-1 score, ignoring case and padding."""
    if not isinstance(severity, str):
        return _DEFAULT_SEVERITY_SCORE
    return _SEVERITY_SCORES.get(severity.strip().lower(), _DEFAULT_SEVERITY_SCORE)


def _safe_metric(name: str, func: Callable[..., Any], *args: Any) -> Any:
    """Call ``func(*args)``; on any failure log a warning and return None.

    Metrics are best-effort: one failing metric must not abort an evaluation.
    """
    try:
        return func(*args)
    except Exception:
        logger.warning("Metric %r failed; it will be omitted.", name, exc_info=True)
        return None


def _parse_score_out_of_ten(response: str, default: float = 5.0) -> float:
    """Extract ``Score: X/10`` from a VLM response, clamped to [0, 10]."""
    match = _SCORE_OUT_OF_TEN_RE.search(response or "")
    if not match:
        return default
    return _clamp(float(match.group(1)), 0.0, 10.0)


def _parse_ranking(raw: Any, count: int) -> Optional[List[int]]:
    """Return ``raw`` as a permutation of 1..count, or None if it is not one."""
    if not isinstance(raw, list) or len(raw) != count:
        return None
    ranking = []
    for item in raw:
        number = _to_float(item)
        if number is None or number != int(number):
            return None
        ranking.append(int(number))
    return ranking if sorted(ranking) == list(range(1, count + 1)) else None


def _parse_scores(raw: Any, count: int) -> Optional[List[float]]:
    """Return ``raw`` as a list of ``count`` finite floats, or None."""
    if not isinstance(raw, list) or len(raw) != count:
        return None
    numbers = [_to_float(item) for item in raw]
    if any(number is None for number in numbers):
        return None
    return numbers


@dataclass
class PrimitiveResult:
    """Result for a single Soft-TIFA primitive."""
    content: str
    type: str
    question: str
    answer: str
    score: float
    reasoning: Optional[str] = None


@dataclass
class SoftTIFAResult:
    """Soft-TIFA evaluation result."""
    primitives_count: int
    atom_score: float
    prompt_score: float
    passed: bool
    primitive_results: List[PrimitiveResult]


@dataclass
class VLMAssessmentResult:
    """VLM-as-Judge assessment result."""
    technical_quality: float
    aesthetic_appeal: float
    realism: float
    semantic_accuracy: Optional[float]
    artifacts_detected: List[str]
    artifacts_severity: str
    overall: float
    reasoning: Optional[str] = None


@dataclass
class TechnicalMetricsResult:
    """Technical metrics result."""
    clip_score: Optional[float] = None
    sharpness: Optional[float] = None
    colorfulness: Optional[float] = None
    contrast: Optional[float] = None


@dataclass
class ScoreBreakdown:
    """Detailed score breakdown by category."""
    prompt_alignment: Optional[float] = None
    technical_quality: Optional[float] = None
    aesthetic_appeal: Optional[float] = None
    realism: Optional[float] = None
    artifacts: Optional[float] = None


@dataclass
class AggregatedScore:
    """Comprehensive aggregated scoring."""
    overall: float
    grade: str
    passed: bool
    confidence: float
    breakdown: ScoreBreakdown
    weights_used: Dict[str, float]
    recommendation: str


@dataclass
class ImageEvalResult:
    """Complete image evaluation result."""
    score: AggregatedScore
    soft_tifa: Optional[SoftTIFAResult] = None
    vlm_assessment: Optional[VLMAssessmentResult] = None
    technical_metrics: Optional[TechnicalMetricsResult] = None
    evaluation_time: float = 0.0


@dataclass
class InstructionFollowingResult:
    """Instruction following evaluation result."""
    edit_primitives: List[Dict]
    primitive_scores: List[Dict]
    overall_score: float
    reasoning: Optional[str] = None


@dataclass
class PreservationResult:
    """Preservation evaluation result."""
    lpips_score: Optional[float] = None
    ssim_score: Optional[float] = None
    psnr_score: Optional[float] = None
    overall_score: float = 0.0


@dataclass
class EditQualityResult:
    """Edit quality assessment result."""
    technical_score: float
    aesthetic_score: float
    coherence_score: float
    artifacts: List[str]
    artifact_severity: str
    overall_score: float
    reasoning: Optional[str] = None


@dataclass
class EditScoreBreakdown:
    """Detailed score breakdown for editing evaluation."""
    instruction_following: Optional[float] = None
    preservation: Optional[float] = None
    edit_quality: Optional[float] = None
    artifacts: Optional[float] = None


@dataclass
class EditAggregatedScore:
    """Comprehensive aggregated scoring for editing."""
    overall: float
    grade: str
    passed: bool
    confidence: float
    breakdown: EditScoreBreakdown
    weights_used: Dict[str, float]
    recommendation: str


@dataclass
class EditEvalResult:
    """Complete edit evaluation result."""
    score: EditAggregatedScore
    instruction_following: Optional[InstructionFollowingResult] = None
    preservation: Optional[PreservationResult] = None
    edit_quality: Optional[EditQualityResult] = None
    evaluation_time: float = 0.0


@dataclass
class ComparisonRanking:
    """Ranking result for a single criterion."""
    criterion: str
    ranking: List[int]  # 1-based image numbers, best to worst
    scores: List[float]  # Normalized scores per image
    reasoning: Optional[str] = None


@dataclass
class ComparisonResult:
    """Complete comparison result for multiple images."""
    num_images: int
    prompt: str
    overall_ranking: List[int]  # 1-based image numbers, best to worst
    overall_scores: List[float]  # Normalized scores per image
    rankings_by_criterion: Dict[str, ComparisonRanking]
    winner_index: int  # 0-based index into the input image list
    winner_reasoning: str
    individual_scores: List[AggregatedScore]  # Individual evaluation scores
    individual_results: List["ImageEvalResult"] = field(default_factory=list)  # Full evaluation results
    evaluation_time: float = 0.0


class _QwenVLMMixin:
    """Shared loading and generation code for the Qwen2.5-VL judge model.

    Both evaluators previously carried identical copies of these methods; they
    live here once so the two classes cannot drift apart.
    """

    _VLM_MODEL_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"

    def _init_vlm(self, device: str) -> None:
        """Set ``self.device`` and load the VLM model and processor.

        ``device`` is honoured only when CUDA is available; otherwise "cpu" is
        used. The VLM itself is placed by ``device_map="auto"``.
        """
        # Heavy dependencies are imported lazily so importing this module is cheap.
        import torch
        from transformers import AutoModelForImageTextToText, AutoProcessor

        self.device = device if torch.cuda.is_available() else "cpu"

        self.vlm_model = AutoModelForImageTextToText.from_pretrained(
            self._VLM_MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.vlm_processor = AutoProcessor.from_pretrained(
            self._VLM_MODEL_NAME,
            trust_remote_code=True,
        )

    def _run_vlm(
        self,
        messages: List[Dict[str, Any]],
        images: Optional[List[Image.Image]],
        max_new_tokens: int,
    ) -> str:
        """Render ``messages`` with the chat template, generate greedily, decode.

        Only the newly generated tokens (after the prompt) are decoded.
        """
        import torch

        text = self.vlm_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        processor_kwargs: Dict[str, Any] = {"text": [text], "return_tensors": "pt"}
        if images:
            processor_kwargs["images"] = list(images)
        inputs = self.vlm_processor(**processor_kwargs).to(self.vlm_model.device)

        with torch.no_grad():
            outputs = self.vlm_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        # Strip the echoed prompt tokens from the front of the output sequence.
        generated = outputs[0][inputs.input_ids.shape[1]:]
        return self.vlm_processor.decode(generated, skip_special_tokens=True)

    def _vlm_generate(self, image: Image.Image, prompt: str) -> str:
        """Generate response from VLM with image."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        return self._run_vlm(messages, [image], max_new_tokens=1024)

    def _vlm_text_generate(self, prompt: str) -> str:
        """Generate response from VLM (text only)."""
        messages = [{"role": "user", "content": prompt}]
        return self._run_vlm(messages, None, max_new_tokens=1024)

    def _vlm_generate_multi_image(self, images: List[Image.Image], prompt: str) -> str:
        """Generate response from VLM with multiple images."""
        content: List[Dict[str, Any]] = [{"type": "image", "image": img} for img in images]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        return self._run_vlm(messages, images, max_new_tokens=2048)


class ImageEvaluator(_QwenVLMMixin):
    """
    AI-Generated Image Quality Evaluator

    Evaluates AI-generated images using:
    - Soft-TIFA: Atomic prompt decomposition for precise alignment scoring
    - VLM-as-Judge: Human-like holistic assessment with reasoning
    - Technical Metrics: Sharpness, colorfulness, contrast, CLIP score
    """

    def __init__(self, device: str = "cuda"):
        """Initialize evaluator with models.

        Args:
            device: Preferred device for CLIP; falls back to "cpu" when CUDA
                is unavailable.
        """
        self._init_vlm(device)

        # Load CLIP for text-image alignment
        import open_clip

        self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
            'ViT-B-32', pretrained='openai'
        )
        self.clip_model = self.clip_model.to(self.device).eval()
        self.clip_tokenizer = open_clip.get_tokenizer('ViT-B-32')

    def evaluate_soft_tifa(self, image: Image.Image, prompt: str) -> SoftTIFAResult:
        """Run Soft-TIFA evaluation with atomic prompt decomposition.

        The prompt is decomposed by the VLM into primitives; each primitive is
        then verified against the image with a yes/no question. At most
        ``_MAX_PRIMITIVES`` primitives are checked.

        Args:
            image: Image to evaluate.
            prompt: Text prompt the image was generated from.

        Returns:
            SoftTIFAResult; all-zero and ``passed=False`` when the
            decomposition yields no usable primitive.
        """
        # Step 1: Decompose prompt into primitives
        decomposition_prompt = f'''Decompose this text-to-image prompt into atomic visual primitives.

Prompt: "{prompt}"

For each primitive, identify:
- content: The specific visual element (e.g., "a red car", "sunset sky")
- type: One of [object, attribute, count, relation, action, scene, style]
- importance: How critical (0.5-1.0)

Example for "A cat sitting on a red chair":
[
  {{"content": "cat", "type": "object", "importance": 1.0}},
  {{"content": "chair", "type": "object", "importance": 0.9}},
  {{"content": "red chair", "type": "attribute", "importance": 0.8}},
  {{"content": "cat sitting on chair", "type": "relation", "importance": 0.9}}
]

Return ONLY valid JSON array for the given prompt:'''

        decomp_response = self._vlm_text_generate(decomposition_prompt)
        primitives = parse_json_robust(decomp_response, fallback=[])

        # The model output is untrusted: keep only dict entries that carry
        # non-empty string content, so malformed items cannot crash the loop
        # or produce meaningless questions.
        usable = []
        if isinstance(primitives, list):
            usable = [
                prim for prim in primitives
                if isinstance(prim, dict)
                and isinstance(prim.get("content"), str)
                and prim["content"].strip()
            ]

        if not usable:
            return SoftTIFAResult(
                primitives_count=0,
                atom_score=0.0,
                prompt_score=0.0,
                passed=False,
                primitive_results=[],
            )

        # Step 2: Evaluate each primitive via VQA
        vqa_templates = {
            "object": "Is there a {content} in this image?",
            "attribute": "Does the image show {content}?",
            "count": "Are there {content}?",
            "relation": "Is it true that {content}?",
            "action": "Is {content} happening in this image?",
            "scene": "Does this image depict {content}?",
            "style": "Is this image in {content} style?",
        }

        primitive_results = []
        for prim in usable[:_MAX_PRIMITIVES]:
            content = prim["content"].strip()
            ptype = prim.get("type", "object")
            if not isinstance(ptype, str):
                ptype = "object"

            template = vqa_templates.get(ptype, vqa_templates["object"])
            question = template.format(content=content)

            vqa_prompt = f"""{question}

Answer Yes or No with confidence (0-100%).
Format: [Yes/No] (confidence: X%) - brief reasoning"""

            response = self._vlm_generate(image, vqa_prompt)

            # Parse response
            response_lower = response.lower().strip()
            is_yes = response_lower.startswith("yes") or "[yes]" in response_lower
            answer = "yes" if is_yes else "no"

            confidence = 0.5
            conf_match = _CONFIDENCE_RE.search(response_lower)
            if conf_match:
                # Clamp so a reported ">100%" cannot push the score below 0.
                confidence = _clamp(float(conf_match.group(1)) / 100.0, 0.0, 1.0)

            reasoning = None
            _, dash, tail = response.partition("-")
            if dash:
                reasoning = tail.strip()[:200]

            # A confident "no" counts against the primitive.
            score = confidence if is_yes else (1.0 - confidence)

            primitive_results.append(PrimitiveResult(
                content=content,
                type=ptype,
                question=question,
                answer=answer,
                score=score,
                reasoning=reasoning,
            ))

        # Aggregate scores
        scores = [result.score for result in primitive_results]
        atom_score = sum(scores) / len(scores)
        geo_mean = geometric_mean(scores)
        prompt_score = 0.7 * atom_score + 0.3 * geo_mean

        return SoftTIFAResult(
            primitives_count=len(primitive_results),
            atom_score=atom_score,
            prompt_score=prompt_score,
            passed=prompt_score >= 0.7,
            primitive_results=primitive_results,
        )

    def evaluate_vlm_judge(self, image: Image.Image, prompt: Optional[str]) -> VLMAssessmentResult:
        """Run VLM-as-Judge holistic assessment.

        Args:
            image: Image to assess.
            prompt: Generation prompt; when falsy, semantic accuracy is not
                requested and is reported as None.

        Returns:
            VLMAssessmentResult with 0-10 ratings. Ratings the model omits or
            garbles default to 5.0; if the response is not a JSON object at
            all, every rating is 5.0 and severity is "unknown".
        """
        prompt_context = f'Original prompt: "{prompt}"' if prompt else ""
        semantic_field = '"semantic_accuracy": {"score": 8, "reasoning": "matches prompt well"},' if prompt else ""

        eval_prompt = f"""Evaluate this AI-generated image on multiple dimensions.

{prompt_context}

Rate each dimension from 1-10:
- **Technical Quality**: Sharpness, noise level, color accuracy, resolution
- **Aesthetic Appeal**: Composition, color harmony, visual balance, style
- **Realism**: Physical plausibility, lighting consistency, proportions
{('- **Semantic Accuracy**: How well it matches the prompt' if prompt else '')}
- **AI Artifacts**: Detect issues like distorted faces/hands, extra limbs, text errors

Example output:
{{
  "technical_quality": {{"score": 8, "reasoning": "sharp with good colors"}},
  "aesthetic_appeal": {{"score": 7, "reasoning": "balanced composition"}},
  "realism": {{"score": 6, "reasoning": "slightly off proportions"}},
  {semantic_field}
  "artifacts": {{"detected": ["slightly distorted fingers"], "severity": "minor"}},
  "overall": {{"score": 7, "reasoning": "good quality with minor issues"}}
}}

Now evaluate this image and return ONLY valid JSON:"""

        response = self._vlm_generate(image, eval_prompt)
        data = parse_json_robust(response, fallback=None)

        if not data or not isinstance(data, dict):
            # Fallback: neutral ratings when the response is unusable.
            return VLMAssessmentResult(
                technical_quality=5.0,
                aesthetic_appeal=5.0,
                realism=5.0,
                semantic_accuracy=5.0 if prompt else None,
                artifacts_detected=[],
                artifacts_severity="unknown",
                overall=5.0,
            )

        artifacts = data.get("artifacts")
        if isinstance(artifacts, dict):
            detected = artifacts.get("detected", [])
            severity = artifacts.get("severity", "unknown")
        else:
            detected = []
            severity = "unknown"

        overall_entry = data.get("overall")
        reasoning = overall_entry.get("reasoning") if isinstance(overall_entry, dict) else None

        return VLMAssessmentResult(
            technical_quality=_rating_from(data, "technical_quality"),
            aesthetic_appeal=_rating_from(data, "aesthetic_appeal"),
            realism=_rating_from(data, "realism"),
            semantic_accuracy=_rating_from(data, "semantic_accuracy") if prompt else None,
            artifacts_detected=detected if isinstance(detected, list) else [],
            artifacts_severity=severity if isinstance(severity, str) else "unknown",
            overall=_rating_from(data, "overall"),
            reasoning=reasoning,
        )

    def evaluate_technical_metrics(self, image: Image.Image, prompt: Optional[str]) -> TechnicalMetricsResult:
        """Calculate technical quality metrics.

        Every metric is best-effort: a metric that raises is logged and
        reported as None. The CLIP score is computed only when ``prompt`` is
        truthy.
        """
        clip_score = None
        if prompt:
            clip_score = _safe_metric(
                "clip_score",
                calculate_clip_score,
                image, prompt, self.clip_model, self.clip_preprocess,
                self.clip_tokenizer, self.device,
            )

        return TechnicalMetricsResult(
            clip_score=clip_score,
            sharpness=_safe_metric("sharpness", calculate_sharpness, image),
            colorfulness=_safe_metric("colorfulness", calculate_colorfulness, image),
            contrast=_safe_metric("contrast", calculate_contrast, image),
        )

    def _calculate_aggregated_score(
        self,
        soft_tifa: Optional[SoftTIFAResult],
        vlm: Optional[VLMAssessmentResult],
        technical: Optional[TechnicalMetricsResult],
        has_prompt: bool,
    ) -> AggregatedScore:
        """Calculate comprehensive aggregated score.

        Each category score is the mean of whichever of its inputs are
        available; the overall score is the weighted mean over the categories
        that have a score. Prompt alignment carries zero weight when
        ``has_prompt`` is False.

        NOTE(review): technical metrics and the CLIP score are averaged with
        0-1 values as-is, so they are assumed to be normalized to [0, 1] by
        ``metrics`` - confirm.
        """
        # Prompt alignment scores
        alignment_inputs = []
        if soft_tifa:
            alignment_inputs.append(soft_tifa.prompt_score)
        if vlm and vlm.semantic_accuracy is not None:
            alignment_inputs.append(vlm.semantic_accuracy / 10.0)
        if technical and technical.clip_score is not None:
            alignment_inputs.append(technical.clip_score)
        prompt_alignment = _mean_or_none(alignment_inputs)

        # Technical quality scores
        tech_inputs = []
        if technical:
            if technical.sharpness is not None:
                tech_inputs.append(technical.sharpness)
            if technical.contrast is not None:
                tech_inputs.append(technical.contrast)
        if vlm:
            tech_inputs.append(vlm.technical_quality / 10.0)
        technical_quality = _mean_or_none(tech_inputs)

        # Aesthetic appeal scores
        aesthetic_inputs = []
        if technical and technical.colorfulness is not None:
            aesthetic_inputs.append(technical.colorfulness)
        if vlm:
            aesthetic_inputs.append(vlm.aesthetic_appeal / 10.0)
        aesthetic_appeal = _mean_or_none(aesthetic_inputs)

        # Realism and artifacts come from the VLM only.
        realism = vlm.realism / 10.0 if vlm else None
        artifacts_score = _severity_score(vlm.artifacts_severity) if vlm else None

        # Calculate weighted overall
        score_map = {
            "prompt_alignment": prompt_alignment,
            "technical_quality": technical_quality,
            "aesthetic_appeal": aesthetic_appeal,
            "realism": realism,
            "artifacts": artifacts_score,
        }
        category_weights = {
            "prompt_alignment": 0.45 if has_prompt else 0.0,  # Highest priority
            "technical_quality": 0.20,
            "aesthetic_appeal": 0.15,
            "realism": 0.10,
            "artifacts": 0.10,
        }

        available = {key: value for key, value in score_map.items() if value is not None}
        weighted_sum = sum(value * category_weights[key] for key, value in available.items())
        total_weight = sum(category_weights[key] for key in available)
        overall = weighted_sum / total_weight if total_weight > 0 else 0.0

        # Confidence: fraction of expected categories that produced a score.
        # Capped at 1.0 because a caller may supply alignment data while
        # passing has_prompt=False.
        max_metrics = 5 if has_prompt else 4
        confidence = min(1.0, len(available) / max_metrics)

        recommendation = self._generate_recommendation(score_map, overall)

        # Normalized weights; guarded so a zero total weight cannot divide by zero.
        if total_weight > 0:
            normalized_weights = {key: category_weights[key] / total_weight for key in available}
        else:
            normalized_weights = {}

        def _rounded(value: Optional[float]) -> Optional[float]:
            return round(value, 3) if value is not None else None

        return AggregatedScore(
            overall=round(overall, 3),
            grade=score_to_grade(overall),
            passed=overall >= 0.7,
            confidence=round(confidence, 2),
            breakdown=ScoreBreakdown(
                prompt_alignment=_rounded(prompt_alignment),
                technical_quality=_rounded(technical_quality),
                aesthetic_appeal=_rounded(aesthetic_appeal),
                realism=_rounded(realism),
                artifacts=_rounded(artifacts_score),
            ),
            weights_used=normalized_weights,
            recommendation=recommendation,
        )

    def _generate_recommendation(self, scores: Dict, overall: float) -> str:
        """Generate recommendation based on scores.

        The weakest category is the first one with the lowest score strictly
        below 1.0; categories without a score are ignored.
        """
        candidates = [
            (key, value) for key, value in scores.items()
            if value is not None and value < 1.0
        ]
        if candidates:
            weakest, weakest_score = min(candidates, key=lambda pair: pair[1])
        else:
            weakest, weakest_score = None, 1.0

        if overall >= 0.85:
            return "Excellent quality image. Ready for production use."
        if overall >= 0.70:
            if weakest and weakest_score < 0.7:
                suggestions = {
                    "prompt_alignment": "Consider regenerating with clearer prompt.",
                    "technical_quality": "Image has quality issues. Try higher resolution.",
                    "aesthetic_appeal": "Composition could be improved.",
                    "realism": "Physical inconsistencies detected.",
                    "artifacts": "AI artifacts present. Consider regeneration.",
                }
                return f"Good overall. Improvement: {suggestions.get(weakest, weakest)}"
            return "Good quality image. Minor improvements possible."
        if overall >= 0.50:
            return f"Moderate quality. Main issue: {weakest.replace('_', ' ') if weakest else 'overall'}."
        return "Low quality. Regeneration strongly recommended."

    def evaluate(
        self,
        image: Image.Image,
        prompt: Optional[str] = None,
        include_soft_tifa: bool = True,
        include_vlm: bool = True,
        include_technical: bool = True,
    ) -> ImageEvalResult:
        """
        Evaluate an AI-generated image.

        Args:
            image: PIL Image to evaluate
            prompt: Optional text prompt used to generate the image
            include_soft_tifa: Run Soft-TIFA evaluation (requires prompt)
            include_vlm: Run VLM-as-Judge assessment
            include_technical: Calculate technical metrics

        Returns:
            ImageEvalResult with all evaluation components
        """
        start_time = time.time()

        soft_tifa_result = None
        vlm_result = None
        technical_result = None

        if include_soft_tifa and prompt:
            soft_tifa_result = self.evaluate_soft_tifa(image, prompt)

        if include_vlm:
            vlm_result = self.evaluate_vlm_judge(image, prompt)

        if include_technical:
            technical_result = self.evaluate_technical_metrics(image, prompt)

        aggregated = self._calculate_aggregated_score(
            soft_tifa=soft_tifa_result,
            vlm=vlm_result,
            technical=technical_result,
            # Truthiness, matching the checks above: an empty prompt counts
            # as "no prompt" everywhere.
            has_prompt=bool(prompt),
        )

        return ImageEvalResult(
            score=aggregated,
            soft_tifa=soft_tifa_result,
            vlm_assessment=vlm_result,
            technical_metrics=technical_result,
            evaluation_time=time.time() - start_time,
        )

    def compare_images(
        self,
        images: List[Image.Image],
        prompt: str,
    ) -> ComparisonResult:
        """
        Compare multiple images (2-4) against a prompt.

        Each image is first evaluated individually; those scores provide the
        default rankings. A single multi-image VLM call then ranks the images
        side by side, and every well-formed part of its answer overrides the
        corresponding default.

        Args:
            images: List of 2-4 PIL Images to compare
            prompt: The text prompt to evaluate against

        Returns:
            ComparisonResult with rankings and scores

        Raises:
            ValueError: If fewer than 2 or more than 4 images are given.
        """
        start_time = time.time()
        num_images = len(images)

        if num_images < 2 or num_images > 4:
            raise ValueError("Must provide 2-4 images for comparison")

        # Step 1: Get individual scores for each image
        individual_results = [
            self.evaluate(img, prompt, include_soft_tifa=True, include_vlm=True, include_technical=True)
            for img in images
        ]
        individual_scores = [result.score for result in individual_results]

        # Step 2: Direct multi-image comparison via VLM
        image_labels = ", ".join([f"Image {i+1}" for i in range(num_images)])

        comparison_prompt = f'''You are comparing {num_images} AI-generated images for the prompt: "{prompt}"

The images are labeled {image_labels} (in order from left to right).

Evaluate and rank ALL images for each criterion. Return a JSON object with rankings (1=best, {num_images}=worst) and scores (0-10).

Criteria:
1. prompt_alignment: How well does each image match the prompt?
2. technical_quality: Sharpness, clarity, no artifacts or distortions
3. aesthetic_appeal: Composition, color harmony, visual appeal
4. realism: Physical plausibility, lighting, proportions

Example output format:
{{
  "prompt_alignment": {{"ranking": [2, 1, 3], "scores": [9, 8, 6], "reasoning": "Image 2 captures all elements..."}},
  "technical_quality": {{"ranking": [1, 2, 3], "scores": [8, 7, 5], "reasoning": "Image 1 is sharpest..."}},
  "aesthetic_appeal": {{"ranking": [2, 1, 3], "scores": [9, 8, 6], "reasoning": "Image 2 has best composition..."}},
  "realism": {{"ranking": [1, 2, 3], "scores": [8, 7, 5], "reasoning": "Image 1 looks most natural..."}},
  "overall": {{"ranking": [2, 1, 3], "scores": [8.5, 7.5, 5.5], "winner": 2, "reasoning": "Image 2 best balances all criteria..."}}
}}

Return ONLY valid JSON:'''

        response = self._vlm_generate_multi_image(images, comparison_prompt)
        data = parse_json_robust(response, fallback=None)

        # Defaults derived from the individual evaluations.
        overall_scores = [result.score.overall for result in individual_results]
        sorted_indices = sorted(range(num_images), key=lambda i: overall_scores[i], reverse=True)
        overall_ranking = [i + 1 for i in sorted_indices]  # Convert to 1-indexed
        winner_index = sorted_indices[0]
        winner_reasoning = (
            f"Image {winner_index + 1} achieved the highest overall score "
            f"({overall_scores[winner_index]:.3f})."
        )

        # Build rankings_by_criterion from individual score breakdowns
        criteria = ["prompt_alignment", "technical_quality", "aesthetic_appeal", "realism"]
        rankings_by_criterion = {}
        for criterion in criteria:
            crit_scores = []
            for score in individual_scores:
                value = getattr(score.breakdown, criterion)
                # Missing category scores rank as neutral.
                crit_scores.append(value if value is not None else 0.5)
            order = sorted(range(num_images), key=lambda i: crit_scores[i], reverse=True)
            rankings_by_criterion[criterion] = ComparisonRanking(
                criterion=criterion,
                ranking=[i + 1 for i in order],
                scores=crit_scores,
                reasoning=None,
            )

        # Try to use VLM comparison if available (overrides individual-based
        # rankings). The answer is untrusted, so each piece is validated
        # before it replaces a default.
        if data and isinstance(data, dict):
            for criterion in criteria:
                crit_data = data.get(criterion)
                if not isinstance(crit_data, dict):
                    continue
                ranking = _parse_ranking(crit_data.get("ranking"), num_images)
                scores = _parse_scores(crit_data.get("scores"), num_images)
                if ranking is None or scores is None:
                    continue
                reasoning = crit_data.get("reasoning", "")
                rankings_by_criterion[criterion] = ComparisonRanking(
                    criterion=criterion,
                    ranking=ranking,
                    # Normalize scores to 0-1
                    scores=[s / 10.0 for s in scores],
                    reasoning=reasoning if isinstance(reasoning, str) else None,
                )

            # Overall from VLM
            overall_data = data.get("overall")
            if isinstance(overall_data, dict):
                vlm_ranking = _parse_ranking(overall_data.get("ranking"), num_images)
                vlm_scores = _parse_scores(overall_data.get("scores"), num_images)
                vlm_reasoning = overall_data.get("reasoning", "")

                # Accept the winner as an int, an integral float, or a digit string.
                winner_number = _to_float(overall_data.get("winner"))
                vlm_winner = None
                if (
                    winner_number is not None
                    and winner_number == int(winner_number)
                    and 1 <= int(winner_number) <= num_images
                ):
                    vlm_winner = int(winner_number)

                if vlm_scores is not None:
                    # Decide the scale once for the whole list: values above 1
                    # mean the model used 0-10. Deciding per element would
                    # leave a 0-10 score of exactly 1 unscaled and make the
                    # worst image look perfect.
                    scale = 10.0 if max(vlm_scores) > 1 else 1.0
                    overall_scores = [s / scale for s in vlm_scores]

                vlm_decided = False
                if vlm_ranking is not None:
                    overall_ranking = vlm_ranking
                    # Keep the winner consistent with the adopted ranking
                    # unless the model names one explicitly.
                    winner_index = vlm_ranking[0] - 1
                    vlm_decided = True
                if vlm_winner is not None:
                    winner_index = vlm_winner - 1
                    vlm_decided = True

                if isinstance(vlm_reasoning, str) and vlm_reasoning:
                    winner_reasoning = vlm_reasoning
                elif vlm_decided:
                    winner_reasoning = (
                        f"Image {winner_index + 1} was ranked first in the direct VLM comparison."
                    )

        return ComparisonResult(
            num_images=num_images,
            prompt=prompt,
            overall_ranking=overall_ranking,
            overall_scores=overall_scores,
            rankings_by_criterion=rankings_by_criterion,
            winner_index=max(0, min(winner_index, num_images - 1)),
            winner_reasoning=winner_reasoning,
            individual_scores=individual_scores,
            individual_results=individual_results,
            evaluation_time=time.time() - start_time,
        )


class EditEvaluator(_QwenVLMMixin):
    """
    Image Editing Evaluator

    Evaluates instruction-based image editing using:
    - Instruction Following: Were the requested edits applied?
    - Preservation: Were non-edited regions maintained?
    - Edit Quality: Is the edit seamless and high-quality?
    """

    def __init__(self, device: str = "cuda"):
        """Initialize evaluator with models.

        Args:
            device: Preferred device for LPIPS; falls back to "cpu" when CUDA
                is unavailable.
        """
        # Import first so a missing dependency fails before the VLM is loaded.
        import lpips

        self._init_vlm(device)

        # Load LPIPS
        self.lpips_model = lpips.LPIPS(net='alex').to(self.device)

    def evaluate_instruction_following(self, edited_image: Image.Image, instruction: str) -> InstructionFollowingResult:
        """Evaluate if editing instruction was followed.

        The instruction is decomposed into atomic edits (at most
        ``_MAX_EDITS`` are checked), each verified against the edited image
        with a 0-10 rating. If decomposition yields no usable edit, the whole
        instruction is rated in a single holistic query.

        Returns:
            InstructionFollowingResult whose ``overall_score`` is in [0, 1].
        """
        decomp_prompt = f'''Analyze this image editing instruction and decompose into atomic edits.

Instruction: "{instruction}"

Example for "Change the sky to sunset and add a bird":
{{
  "edits": [
    {{"content": "change sky color to sunset", "type": "modify", "target": "sky", "expected_result": "orange/purple sunset sky"}},
    {{"content": "add a bird", "type": "add", "target": "sky area", "expected_result": "visible bird in the scene"}}
  ]
}}

Return ONLY valid JSON for the given instruction:'''

        decomp_response = self._vlm_text_generate(decomp_prompt)
        data = parse_json_robust(decomp_response, fallback={})

        raw_edits = data.get("edits", []) if isinstance(data, dict) else []
        # Model output is untrusted: ignore entries that are not JSON objects.
        edits = [edit for edit in raw_edits if isinstance(edit, dict)] if isinstance(raw_edits, list) else []

        if not edits:
            # Fallback: evaluate holistically
            verify_prompt = f'''Evaluate if this image correctly shows the result of the edit:

Edit instruction: "{instruction}"

Rate success from 0-10.
Format: Score: X/10 - Reasoning'''

            response = self._vlm_generate(edited_image, verify_prompt)
            score = _parse_score_out_of_ten(response)

            return InstructionFollowingResult(
                edit_primitives=[{"content": instruction, "type": "unknown"}],
                primitive_scores=[{"edit": instruction, "score": score}],
                overall_score=score / 10.0,
                reasoning=response[:200] if response else None,
            )

        # Evaluate each edit
        checked_edits = edits[:_MAX_EDITS]
        primitive_scores = []
        for edit in checked_edits:
            content = edit.get("content", "")
            target = edit.get("target", "the image")
            expected = edit.get("expected_result", content)

            verify_prompt = f'''Verify if this edit was applied:

Edit: {content}
Target: {target}
Expected: {expected}

Rate from 0-10.
Format: Score: X/10 - Reasoning'''

            response = self._vlm_generate(edited_image, verify_prompt)
            primitive_scores.append({
                "edit": content,
                "score": _parse_score_out_of_ten(response),
                "reasoning": response[:100] if response else None,
            })

        overall = sum(item["score"] for item in primitive_scores) / len(primitive_scores)

        return InstructionFollowingResult(
            edit_primitives=checked_edits,
            primitive_scores=primitive_scores,
            overall_score=overall / 10.0,
        )

    def evaluate_preservation(self, source_image: Image.Image, edited_image: Image.Image) -> PreservationResult:
        """Evaluate if non-edited regions were preserved.

        The combined score is half LPIPS similarity (``1 - lpips``, floored
        at 0) and half the mean of the other available metrics. With only one
        kind available, that kind alone is used; with none, the score is a
        neutral 0.5.

        NOTE(review): SSIM and PSNR are averaged as-is, so ``calculate_psnr``
        is assumed to return a value normalized to [0, 1] - confirm.
        """
        lpips_score = _safe_metric(
            "lpips", calculate_lpips, source_image, edited_image, self.lpips_model, self.device
        )
        ssim_score = _safe_metric("ssim", calculate_ssim, source_image, edited_image)
        psnr_score = _safe_metric("psnr", calculate_psnr, source_image, edited_image)

        lpips_similarity = max(0.0, 1.0 - lpips_score) if lpips_score is not None else None
        # Tracked separately from LPIPS rather than filtered out by value: a
        # value filter would drop an SSIM/PSNR that happens to equal the
        # LPIPS similarity and skew the average.
        other_scores = [score for score in (ssim_score, psnr_score) if score is not None]

        if lpips_similarity is not None and other_scores:
            preservation_score = 0.5 * lpips_similarity + 0.5 * (sum(other_scores) / len(other_scores))
        elif lpips_similarity is not None:
            preservation_score = lpips_similarity
        elif other_scores:
            preservation_score = sum(other_scores) / len(other_scores)
        else:
            preservation_score = 0.5

        return PreservationResult(
            lpips_score=lpips_score,
            ssim_score=ssim_score,
            psnr_score=psnr_score,
            overall_score=preservation_score,
        )

    def evaluate_edit_quality(self, edited_image: Image.Image, instruction: str) -> EditQualityResult:
        """Evaluate the quality of the edit.

        Returns:
            EditQualityResult with 0-10 sub-scores and an ``overall_score``
            equal to their mean (scaled to 0-1) times an artifact-severity
            penalty. Neutral values (5.0 / 0.5) are returned when the VLM
            response is not a JSON object.
        """
        eval_prompt = f'''Evaluate the quality of this edited image.

Edit instruction: "{instruction}"

Rate each dimension 1-10:
- **Technical**: Seamless blending? Resolution consistent? No visible edit boundaries?
- **Aesthetic**: Natural looking? Color harmony maintained? Visually pleasing?
- **Coherence**: Physically plausible? Lighting/shadows consistent? Proper perspective?
- **Artifacts**: List any issues (blur, color bleeding, unnatural edges, etc.)

Example output:
{{
  "technical": {{"score": 8}},
  "aesthetic": {{"score": 7}},
  "coherence": {{"score": 8}},
  "artifacts": {{"detected": ["slight blur at edge"], "severity": "minor"}}
}}

Return ONLY valid JSON:'''

        response = self._vlm_generate(edited_image, eval_prompt)
        data = parse_json_robust(response, fallback=None)

        if not data or not isinstance(data, dict):
            return EditQualityResult(
                technical_score=5.0,
                aesthetic_score=5.0,
                coherence_score=5.0,
                artifacts=[],
                artifact_severity="unknown",
                overall_score=0.5,
            )

        technical = _rating_from(data, "technical")
        aesthetic = _rating_from(data, "aesthetic")
        coherence = _rating_from(data, "coherence")

        artifacts_data = data.get("artifacts")
        if isinstance(artifacts_data, dict):
            artifacts = artifacts_data.get("detected", [])
            severity = artifacts_data.get("severity", "unknown")
        else:
            artifacts = []
            severity = "unknown"
        if not isinstance(severity, str):
            severity = "unknown"

        penalty = _EDIT_SEVERITY_PENALTIES.get(severity.strip().lower(), _DEFAULT_EDIT_SEVERITY_PENALTY)
        overall = (technical + aesthetic + coherence) / 30.0 * penalty

        return EditQualityResult(
            technical_score=technical,
            aesthetic_score=aesthetic,
            coherence_score=coherence,
            artifacts=artifacts if isinstance(artifacts, list) else [],
            artifact_severity=severity,
            overall_score=overall,
        )

    def _calculate_edit_aggregated_score(
        self,
        instruction_result: InstructionFollowingResult,
        preservation_result: PreservationResult,
        quality_result: EditQualityResult,
    ) -> EditAggregatedScore:
        """Calculate comprehensive aggregated score for editing.

        The overall score is a fixed-weight sum of the four category scores.
        Confidence starts at 0.5 and grows by 0.1 per verified edit primitive,
        capped at 1.0.
        """
        weights = {
            "instruction_following": 0.35,
            "preservation": 0.25,
            "edit_quality": 0.25,
            "artifacts": 0.15,
        }

        instruction_score = instruction_result.overall_score
        preservation_score = preservation_result.overall_score
        edit_quality_score = quality_result.overall_score
        artifacts_score = _severity_score(quality_result.artifact_severity)

        overall = (
            instruction_score * weights["instruction_following"]
            + preservation_score * weights["preservation"]
            + edit_quality_score * weights["edit_quality"]
            + artifacts_score * weights["artifacts"]
        )

        num_primitives = len(instruction_result.primitive_scores)
        confidence = min(1.0, 0.5 + (num_primitives * 0.1))

        recommendation = self._generate_edit_recommendation(
            instruction_score, preservation_score, edit_quality_score, artifacts_score, overall
        )

        return EditAggregatedScore(
            overall=round(overall, 3),
            grade=score_to_grade(overall),
            passed=overall >= 0.7,
            confidence=round(confidence, 2),
            breakdown=EditScoreBreakdown(
                instruction_following=round(instruction_score, 3),
                preservation=round(preservation_score, 3),
                edit_quality=round(edit_quality_score, 3),
                artifacts=round(artifacts_score, 3),
            ),
            weights_used=weights,
            recommendation=recommendation,
        )

    def _generate_edit_recommendation(
        self,
        instruction: float,
        preservation: float,
        quality: float,
        artifacts: float,
        overall: float,
    ) -> str:
        """Generate recommendation for edit quality.

        All arguments are 0-1 scores. Categories below their threshold are
        listed as issues in the returned message.
        """
        issues = []
        if instruction < 0.6:
            issues.append("instruction not fully followed")
        if preservation < 0.6:
            issues.append("too much content changed")
        if quality < 0.6:
            issues.append("edit quality issues")
        if artifacts < 0.7:
            issues.append("visible artifacts")

        if overall >= 0.85:
            return "Excellent edit. Ready for use."
        if overall >= 0.70:
            if issues:
                return f"Good edit with minor issues: {', '.join(issues[:2])}."
            return "Good quality edit. Minor improvements possible."
        if overall >= 0.50:
            if issues:
                return f"Moderate quality. Issues: {', '.join(issues)}."
            return "Moderate quality. Consider regenerating."
        return f"Low quality. Issues: {', '.join(issues) if issues else 'multiple problems'}."

    def evaluate(
        self,
        source_image: Image.Image,
        edited_image: Image.Image,
        instruction: str,
    ) -> EditEvalResult:
        """
        Evaluate an image editing result.

        Args:
            source_image: Original image before editing
            edited_image: Image after editing
            instruction: The editing instruction that was applied

        Returns:
            EditEvalResult with all evaluation components
        """
        start_time = time.time()

        instruction_result = self.evaluate_instruction_following(edited_image, instruction)
        preservation_result = self.evaluate_preservation(source_image, edited_image)
        quality_result = self.evaluate_edit_quality(edited_image, instruction)

        aggregated = self._calculate_edit_aggregated_score(
            instruction_result=instruction_result,
            preservation_result=preservation_result,
            quality_result=quality_result,
        )

        return EditEvalResult(
            score=aggregated,
            instruction_following=instruction_result,
            preservation=preservation_result,
            edit_quality=quality_result,
            evaluation_time=time.time() - start_time,
        )