"""VLM critic prompt and score normalization for agentic T2I upsampling."""

from __future__ import annotations

import json
import math
from typing import Any

from agentic_upsampling.constants import STRICT_OVERALL_THRESHOLD, STRICT_PROMPT_THRESHOLD
from agentic_upsampling.data import PromptItem
from agentic_upsampling.prompt_upsampler import extract_json_object

# Category-specific checklists. All of them are appended to every judge prompt
# (see all_category_check_text); the VLM is asked to apply only the relevant ones,
# so no prompt classification happens in code.
CATEGORY_SECTIONS = {
    "text_commercial_ui": (
        "Text/commercial/UI/logo checks: readable text for logos, labels, posters, "
        "billboards, product packaging, or UI. Verify exact quoted strings, spelling, legibility, typography, "
        "placement, layout, and whether commercial/UI intent is visually clear."
    ),
    "people_anatomy": (
        "People/anatomy checks: if humans, human-like characters, body parts, portraits, or poses are present or "
        "required by the prompt, inspect faces, eyes, hands, fingers, limbs, pose, proportions, expression, "
        "clothing coherence, and physically possible interactions."
    ),
    "fantasy_cartoon_vector": (
        "Fantasy/cartoon/vector/pixel-art checks: if a stylized medium is requested, judge whether stylization is "
        "intentional and clean. Penalize messy geometry, inconsistent line language, broken vector shapes, muddy "
        "palettes, and unwanted photorealistic texture."
    ),
    "photorealistic_physical": (
        "Photorealistic/physical checks: if realism, physical objects, geometry, camera behavior, reflections, "
        "transparent materials, shadows, perspective, scale, or contact matter, judge material realism, lighting "
        "physics, lens plausibility, and whether objects obey real-world physical constraints."
    ),
    "general_scene": (
        "General scene checks: always judge object completeness, layout clarity, subject relationships, background "
        "coherence, visual appeal, and absence of obvious AI artifacts."
    ),
}

# Scores that are always present after normalize_analysis; missing or invalid
# values are coerced to 0.0 by _score.
SCORE_KEYS = (
    "prompt_adherence_score",
    "visual_quality_score",
    "aesthetics_score",
    "physical_plausibility_score",
    "category_score",
    "overall_score",
)

# Optional scores: normalized only when the VLM supplies a non-null value, so
# None keeps meaning "not applicable to this prompt".
_OPTIONAL_SCORE_KEYS = ("text_rendering_score", "photorealism_score")

# Allowed issue severities; anything else is normalized to "moderate".
ISSUE_SEVERITIES = {"minor", "moderate", "severe"}

# Fields forwarded to the prompt rewriter by compact_analysis_for_rewrite.
_REWRITE_KEYS = (
    "overall_score",
    "prompt_adherence_score",
    "visual_quality_score",
    "aesthetics_score",
    "physical_plausibility_score",
    "category_score",
    "text_rendering_score",
    "photorealism_score",
    "issues",
    "prompt_elements",
    "category_findings",
    "improvement_directives",
    "rationale",
)


def all_category_check_text() -> str:
    """Return every category checklist as a bullet list, one ``- `` line per section.

    All sections are included regardless of the prompt (this does not classify
    the prompt); the judge prompt tells the VLM to apply only the relevant ones.
    """
    return "\n".join(f"- {text}" for text in CATEGORY_SECTIONS.values())


def build_judge_prompt(item: PromptItem) -> str:
    """Build the VLM critic prompt using the original user prompt as task context.

    Args:
        item: Prompt record; only ``item.prompt`` is read and embedded verbatim.

    Returns:
        The full instruction text. It asks for a single JSON object whose keys
        match what normalize_analysis consumes. ``text_rendering_score`` and
        ``photorealism_score`` are explicitly allowed to be null, because
        clears_strict_threshold applies the text gate only to a non-null
        ``text_rendering_score``.
    """
    return f"""You are an expert image quality analyst specializing in AI-generated image evaluation.
Your job is to produce an exhaustive defect report. Be meticulous: go beyond obvious problems and look carefully for subtle or background issues too.

The attached image was generated by an AI image model. Analyze this image carefully and list every quality issue you observe.
For each issue give an approximate location and name the specific object or region involved. Report each distinct occurrence separately.

Before finalizing, check these areas, but only report issues you actually see:
- Physics: gravity violations, impossible collisions, implausible trajectories.
- Object deformation: morphing, melting, stretching of solid objects.
- Anatomy: distorted hands, faces, fingers, limbs, or wrong body proportions.
- Lighting and shadows: missing shadows or inconsistent illumination.
- Depth and scale: wrong spatial relationships, perspective issues, or scale inconsistencies.
- Text and numbers: garbled, floating, or incorrect text and digits.
- Visual quality: blur patches, noise, compression blocking, visual artifacts, or low-resolution regions.
- Color: inconsistent coloration, bleeding, or banding.
- Action correctness: prompted actions are correctly displayed.
- Prompt following: missing subjects, wrong objects, wrong setting, or wrong action.

Depending on the prompt, also apply the relevant checks below:
{all_category_check_text()}

The attached image was generated from this prompt:
{item.prompt}

All scores are numbers from 0 (worst) to 10 (best).
Return exactly one JSON object, no markdown fences and no prose outside JSON:
{{
  "prompt_adherence_score": <number 0-10>,
  "visual_quality_score": <number 0-10>,
  "aesthetics_score": <number 0-10>,
  "physical_plausibility_score": <number 0-10>,
  "category_score": <number 0-10 for the category-specific checks that apply>,
  "text_rendering_score": <number 0-10, or null if the prompt requires no rendered text>,
  "photorealism_score": <number 0-10, or null if photorealism is not requested>,
  "overall_score": <number 0-10>,
  "issues": [
    {{
      "category": "<short defect category, e.g. anatomy, text, lighting>",
      "description": "<location, object or region, and what is wrong>",
      "severity": "minor" | "moderate" | "severe"
    }}
  ],
  "prompt_elements": {{
    "<prompt element>": "present" | "absent" | "partial"
  }},
  "category_findings": {{"<check category>": "<finding>"}},
  "improvement_directives": ["<concrete instruction for the next prompt rewrite>"],
  "rationale": "<2-4 concise sentences>"
}}
"""


def parse_analysis_response(text: str) -> dict[str, Any]:
    """Parse and normalize a raw VLM scoring response.

    The JSON object is located by extract_json_object; any error it raises for
    unparseable text propagates unchanged.
    """
    return normalize_analysis(extract_json_object(text))


def normalize_analysis(data: dict[str, Any]) -> dict[str, Any]:
    """Normalize VLM analysis into the schema used by selection and reporting.

    Returns a shallow copy of ``data`` (the input dict is not mutated) where:
      * every SCORE_KEYS entry is a float in [0, 10];
      * optional scores are clamped only when present and non-null;
      * ``issues`` is a list of {category, description, severity} string dicts;
      * ``improvement_directives`` is a list of non-blank stripped strings;
      * ``category_findings`` is a dict (anything else becomes ``{}``);
      * ``threshold_cleared`` records clears_strict_threshold on the result.
    Unknown keys (e.g. ``prompt_elements``, ``rationale``) are passed through as-is.
    """
    normalized = dict(data)
    for key in SCORE_KEYS:
        normalized[key] = _score(normalized.get(key))
    for optional_key in _OPTIONAL_SCORE_KEYS:
        if normalized.get(optional_key) is not None:
            normalized[optional_key] = _score(normalized[optional_key])
    normalized["issues"] = _normalize_issues(normalized.get("issues"))
    normalized["improvement_directives"] = _normalize_directives(normalized.get("improvement_directives"))
    findings = normalized.get("category_findings")
    normalized["category_findings"] = findings if isinstance(findings, dict) else {}
    normalized["threshold_cleared"] = clears_strict_threshold(normalized)
    return normalized


def clears_strict_threshold(analysis: dict[str, Any]) -> bool:
    """Return whether a candidate clears the strict quality milestone.

    All of the following must hold:
      * overall score >= STRICT_OVERALL_THRESHOLD;
      * prompt adherence score >= STRICT_PROMPT_THRESHOLD;
      * no issue has severity "severe";
      * if a text rendering score is present (non-null), it is also
        >= STRICT_PROMPT_THRESHOLD (the same bar as prompt adherence).
    Accepts raw or normalized analysis; scores go through _score either way.
    """
    if _score(analysis.get("overall_score")) < STRICT_OVERALL_THRESHOLD:
        return False
    if _score(analysis.get("prompt_adherence_score")) < STRICT_PROMPT_THRESHOLD:
        return False
    if _has_severe_issue(analysis.get("issues")):
        return False
    text_score = analysis.get("text_rendering_score")
    if text_score is not None:
        return _score(text_score) >= STRICT_PROMPT_THRESHOLD
    return True


def candidate_sort_key(candidate: dict[str, Any]) -> tuple[float, float, float, float, float, int]:
    """Sort key for picking the best candidate (larger key ranks higher).

    Compares overall, prompt adherence, category, visual quality, then
    aesthetics scores. The iteration is negated, so ties favor the earlier
    iteration. A missing or null ``analysis`` scores as all zeros, and a
    missing or null ``iteration`` counts as 0.
    """
    analysis = candidate.get("analysis") or {}
    iteration = int(candidate.get("iteration") or 0)
    return (
        _score(analysis.get("overall_score")),
        _score(analysis.get("prompt_adherence_score")),
        _score(analysis.get("category_score")),
        _score(analysis.get("visual_quality_score")),
        _score(analysis.get("aesthetics_score")),
        -iteration,
    )


def compact_analysis_for_rewrite(analysis: dict[str, Any]) -> dict[str, Any]:
    """Return the VLM fields most useful for the next prompt rewrite.

    Only keys listed in _REWRITE_KEYS that are actually present are copied,
    preserving that order; values are not copied deeply.
    """
    return {key: analysis.get(key) for key in _REWRITE_KEYS if key in analysis}


def analysis_json_text(data: dict[str, Any]) -> str:
    """Serialize compact analysis for prompt inclusion (ASCII-only, 2-space indent)."""
    return json.dumps(data, ensure_ascii=True, indent=2)


def _score(value: Any) -> float:
    """Coerce a raw score to a float clamped to [0, 10].

    None, unparseable values, out-of-range ints that overflow float, and NaN all
    map to 0.0. NaN needs an explicit check: min()/max() do not propagate it
    (``min(10.0, nan)`` is 10.0), so clamping alone would score NaN as a perfect
    10. ``json.loads`` accepts a bare ``NaN`` literal, so this can come from the VLM.
    """
    if value is None:
        return 0.0
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    if math.isnan(number):
        return 0.0
    return max(0.0, min(10.0, number))


def _normalize_directives(value: Any) -> list[str]:
    """Return non-blank directives as stripped strings; non-lists and None entries are dropped."""
    if not isinstance(value, list):
        return []
    stripped = (str(item).strip() for item in value if item is not None)
    return [text for text in stripped if text]


def _normalize_issues(value: Any) -> list[dict[str, str]]:
    """Normalize raw issues into {category, description, severity} string dicts.

    Non-dict entries and entries with a blank description are dropped. A missing
    category becomes "unspecified"; severity is lowercased, and unknown or missing
    values become "moderate".
    """
    if not isinstance(value, list):
        return []
    issues: list[dict[str, str]] = []
    for item in value:
        if not isinstance(item, dict):
            continue
        description = str(item.get("description") or "").strip()
        if not description:
            continue
        category = str(item.get("category") or "unspecified").strip() or "unspecified"
        severity = str(item.get("severity") or "moderate").strip().lower()
        if severity not in ISSUE_SEVERITIES:
            severity = "moderate"
        issues.append({"category": category, "description": description, "severity": severity})
    return issues


def _has_severe_issue(issues: Any) -> bool:
    """Return True if any issue dict has severity "severe".

    Severity is stripped and lowercased, matching _normalize_issues, so raw
    (un-normalized) analysis with e.g. "Severe" is still caught. Non-sequence
    input yields False.
    """
    if not isinstance(issues, (list, tuple)):
        return False
    return any(
        isinstance(item, dict) and str(item.get("severity") or "").strip().lower() == "severe"
        for item in issues
    )