mingyuliutw's picture
Super-squash branch 'main' using huggingface_hub
fdafd05
"""VLM critic prompt and score normalization for agentic T2I upsampling."""
from __future__ import annotations

import json
import math
from typing import Any

from agentic_upsampling.constants import STRICT_OVERALL_THRESHOLD, STRICT_PROMPT_THRESHOLD
from agentic_upsampling.data import PromptItem
from agentic_upsampling.prompt_upsampler import extract_json_object
# Category-specific checklist paragraphs for the critic prompt.
# all_category_check_text() renders every value (in insertion order) as a
# "- " bullet; the keys themselves are not read in this part of the file
# (presumably category identifiers used elsewhere -- verify before renaming).
CATEGORY_SECTIONS = {
    "text_commercial_ui": (
        "Text/commercial/UI/logo checks: readable text for logos, labels, posters, "
        "billboards, product packaging, or UI. Verify exact quoted strings, spelling, legibility, typography, "
        "placement, layout, and whether commercial/UI intent is visually clear."
    ),
    "people_anatomy": (
        "People/anatomy checks: if humans, human-like characters, body parts, portraits, or poses are present or "
        "required by the prompt, inspect faces, eyes, hands, fingers, limbs, pose, proportions, expression, "
        "clothing coherence, and physically possible interactions."
    ),
    "fantasy_cartoon_vector": (
        "Fantasy/cartoon/vector/pixel-art checks: if a stylized medium is requested, judge whether stylization is "
        "intentional and clean. Penalize messy geometry, inconsistent line language, broken vector shapes, muddy "
        "palettes, and unwanted photorealistic texture."
    ),
    "photorealistic_physical": (
        "Photorealistic/physical checks: if realism, physical objects, geometry, camera behavior, reflections, "
        "transparent materials, shadows, perspective, scale, or contact matter, judge material realism, lighting "
        "physics, lens plausibility, and whether objects obey real-world physical constraints."
    ),
    "general_scene": (
        "General scene checks: always judge object completeness, layout clarity, subject relationships, background "
        "coherence, visual appeal, and absence of obvious AI artifacts."
    ),
}
# Scores that normalize_analysis always coerces to a clamped 0-10 float via
# _score (missing or unparseable values become 0.0). The nullable scores
# text_rendering_score and photorealism_score are handled separately there.
SCORE_KEYS = (
    "prompt_adherence_score",
    "visual_quality_score",
    "aesthetics_score",
    "physical_plausibility_score",
    "category_score",
    "overall_score",
)
# Accepted issue severities; _normalize_issues maps any other value to
# "moderate", and only "severe" blocks the strict threshold.
ISSUE_SEVERITIES = {"minor", "moderate", "severe"}
def all_category_check_text() -> str:
    """Render every category checklist paragraph as a newline-joined bullet list.

    No classification is performed: all entries of CATEGORY_SECTIONS are
    included, in dict order, each prefixed with "- ".
    """
    bullet_lines = []
    for section_text in CATEGORY_SECTIONS.values():
        bullet_lines.append(f"- {section_text}")
    return "\n".join(bullet_lines)
def build_judge_prompt(item: PromptItem) -> str:
    """Build the VLM critic prompt using the original user prompt as task context.

    The returned text asks the critic for an exhaustive defect report on the
    attached image, inserts the bulleted checklist from
    ``all_category_check_text()``, interpolates ``item.prompt``, and spells
    out the JSON reply schema whose keys ``normalize_analysis`` reads.

    The schema's braces are doubled (``{{`` / ``}}``) because this is an
    f-string; single braces would be treated as interpolation fields.

    NOTE(review): the schema allows ``null`` for ``text_rendering_score`` and
    ``photorealism_score`` but never tells the critic when to use it, while
    ``clears_strict_threshold`` gates on ``text_rendering_score`` whenever it
    is not None. Consider adding an explicit rule such as "use null when the
    prompt requires no rendered text". Any edit to this string changes model
    behaviour, so it is intentionally left unchanged here.
    """
    return f"""You are an expert image quality analyst specializing in AI-generated image evaluation.
Your job is to produce an exhaustive defect report. Be meticulous: go beyond obvious problems and look carefully for subtle or background issues too.
The attached image was generated by an AI image model.
Analyze this image carefully and list every quality issue you observe.
For each issue give an approximate location and name the specific object or region involved. Report each distinct occurrence separately.
Before finalizing, check these areas, but only report issues you actually see:
- Physics: gravity violations, impossible collisions, implausible trajectories.
- Object deformation: morphing, melting, stretching of solid objects.
- Anatomy: distorted hands, faces, fingers, limbs, or wrong body proportions.
- Lighting and shadows: missing shadows or inconsistent illumination.
- Depth and scale: wrong spatial relationships, perspective issues, or scale inconsistencies.
- Text and numbers: garbled, floating, or incorrect text and digits.
- Visual quality: blur patches, noise, compression blocking, visual artifacts, or low-resolution regions.
- Color: inconsistent coloration, bleeding, or banding.
- Action correctness: prompted actions are correctly displayed.
- Prompt following: missing subjects, wrong objects, wrong setting, or wrong action.
Depending on the prompt, also apply the relevant checks below:
{all_category_check_text()}
The attached image was generated from this prompt:
{item.prompt}
Return exactly one JSON object, no markdown fences and no prose outside JSON:
{{
"prompt_adherence_score": <number 0-10>,
"visual_quality_score": <number 0-10>,
"aesthetics_score": <number 0-10>,
"physical_plausibility_score": <number 0-10>,
"category_score": <number 0-10>,
"text_rendering_score": <number 0-10 or null>,
"photorealism_score": <number 0-10 or null>,
"overall_score": <number 0-10>,
"issues": [
{{
"category": "<concise label>",
"description": "<what failed and where in the image>",
"severity": "minor" | "moderate" | "severe"
}}
],
"prompt_elements": {{
"<key noun or action from the prompt>": "present" | "absent" | "partial"
}},
"category_findings": {{"<check area>": "<concise finding>"}},
"improvement_directives": ["<specific prompt rewrite instruction>"],
"rationale": "<2-4 concise sentences>"
}}
"""
def parse_analysis_response(text: str) -> dict[str, Any]:
    """Convert raw VLM reply text into a normalized analysis dict.

    A JSON object is pulled out of *text* by ``extract_json_object`` and the
    result is passed through ``normalize_analysis``. Any exception raised by
    either step propagates to the caller.
    """
    raw_analysis = extract_json_object(text)
    normalized = normalize_analysis(raw_analysis)
    return normalized
def _optional_score(value: Any) -> float | None:
    """Coerce a nullable score: None/non-numeric -> None, else a clamped 0-10 float."""
    if value is None:
        return None
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # Placeholders such as "N/A", "null" or "" mean "not applicable";
        # mapping them to 0.0 would wrongly fail the strict text gate.
        return None
    return _score(value)


def normalize_analysis(data: dict[str, Any]) -> dict[str, Any]:
    """Normalize VLM analysis into the schema used by selection and reporting.

    Returns a shallow copy of *data* in which:

    * every key in ``SCORE_KEYS`` is coerced by ``_score`` (missing or
      unparseable values become ``0.0``; others are clamped to 0-10);
    * ``text_rendering_score`` / ``photorealism_score``, when present, stay
      ``None`` if null or non-numeric (e.g. ``"N/A"``), otherwise are coerced
      by ``_score``; absent keys stay absent;
    * ``issues`` is cleaned by ``_normalize_issues``;
    * ``improvement_directives`` becomes a list of stripped, non-blank
      strings (a lone string is wrapped in a list, ``None`` entries dropped,
      anything else yields ``[]``);
    * ``category_findings`` is kept if it is a dict, else replaced by ``{}``;
    * ``threshold_cleared`` is set from ``clears_strict_threshold``.
    """
    normalized = dict(data)
    for key in SCORE_KEYS:
        normalized[key] = _score(normalized.get(key))
    for optional_key in ("text_rendering_score", "photorealism_score"):
        if optional_key in normalized:
            normalized[optional_key] = _optional_score(normalized[optional_key])
    normalized["issues"] = _normalize_issues(normalized.get("issues"))

    directives = normalized.get("improvement_directives")
    if isinstance(directives, str):
        # Some models return a single directive instead of a list.
        directives = [directives]
    if isinstance(directives, list):
        # Skip None explicitly: str(None) would otherwise become "None".
        texts = (str(item).strip() for item in directives if item is not None)
        normalized["improvement_directives"] = [text for text in texts if text]
    else:
        normalized["improvement_directives"] = []

    findings = normalized.get("category_findings")
    normalized["category_findings"] = findings if isinstance(findings, dict) else {}
    normalized["threshold_cleared"] = clears_strict_threshold(normalized)
    return normalized
def clears_strict_threshold(analysis: dict[str, Any]) -> bool:
    """Report whether an analysis passes every strict-quality gate.

    Gates are checked in order and evaluation stops at the first failure:
    overall score >= STRICT_OVERALL_THRESHOLD, prompt-adherence score >=
    STRICT_PROMPT_THRESHOLD, no issue marked severe, and -- only when a
    text-rendering score is present (not None) -- that score is also
    >= STRICT_PROMPT_THRESHOLD.
    """
    text_score = analysis.get("text_rendering_score")
    return (
        _score(analysis.get("overall_score")) >= STRICT_OVERALL_THRESHOLD
        and _score(analysis.get("prompt_adherence_score")) >= STRICT_PROMPT_THRESHOLD
        and not _has_severe_issue(analysis.get("issues"))
        and (text_score is None or _score(text_score) >= STRICT_PROMPT_THRESHOLD)
    )
def candidate_sort_key(candidate: dict[str, Any]) -> tuple[float, float, float, float, float, int]:
    """Sort key for picking the best candidate.

    Compares, in priority order, the overall, prompt-adherence, category,
    visual-quality and aesthetics scores (each coerced by ``_score``), then
    the negated iteration number. When the largest key is chosen, ties
    therefore go to the earlier iteration.

    A missing or ``None`` ``analysis`` scores 0.0 on every axis, and a
    missing or ``None`` ``iteration`` counts as 0. Neither case raises
    ``AttributeError`` / ``TypeError`` any more.
    """
    # `or {}` covers an explicit None (e.g. a candidate whose scoring failed).
    analysis = candidate.get("analysis") or {}
    # `or 0` covers an explicit None, which int() would reject.
    iteration = int(candidate.get("iteration") or 0)
    return (
        _score(analysis.get("overall_score")),
        _score(analysis.get("prompt_adherence_score")),
        _score(analysis.get("category_score")),
        _score(analysis.get("visual_quality_score")),
        _score(analysis.get("aesthetics_score")),
        -iteration,
    )
def compact_analysis_for_rewrite(analysis: dict[str, Any]) -> dict[str, Any]:
"""Return the VLM fields most useful for the next prompt rewrite."""
keys = (
"overall_score",
"prompt_adherence_score",
"visual_quality_score",
"aesthetics_score",
"physical_plausibility_score",
"category_score",
"text_rendering_score",
"photorealism_score",
"issues",
"prompt_elements",
"category_findings",
"improvement_directives",
"rationale",
)
return {key: analysis.get(key) for key in keys if key in analysis}
def analysis_json_text(data: dict[str, Any]) -> str:
"""Serialize compact analysis for prompt inclusion."""
return json.dumps(data, ensure_ascii=True, indent=2)
def _score(value: Any) -> float:
if value is None:
return 0.0
try:
number = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(10.0, number))
def _normalize_issues(value: Any) -> list[dict[str, str]]:
    """Clean a raw ``issues`` payload into uniform issue dicts.

    Returns ``[]`` unless *value* is a list. Non-dict entries and entries
    with a blank description are skipped. Each kept entry becomes
    ``{"category", "description", "severity"}``, where:

    * the category defaults to "unspecified";
    * the description is stripped;
    * the severity is lower-cased and falls back to "moderate" when it is
      not in ISSUE_SEVERITIES.
    """
    if not isinstance(value, list):
        return []
    cleaned: list[dict[str, str]] = []
    dict_entries = (entry for entry in value if isinstance(entry, dict))
    for entry in dict_entries:
        text = str(entry.get("description") or "").strip()
        if not text:
            continue
        label = str(entry.get("category") or "unspecified").strip()
        level = str(entry.get("severity") or "moderate").strip().lower()
        cleaned.append(
            {
                "category": label if label else "unspecified",
                "description": text,
                "severity": level if level in ISSUE_SEVERITIES else "moderate",
            }
        )
    return cleaned
def _has_severe_issue(issues: Any) -> bool:
return any(isinstance(item, dict) and item.get("severity") == "severe" for item in issues or [])