image-understanding / schemas.py
shahkushan1's picture
Add Gradio micro-trend app with LLM integrations and prompt loading
2948ced
"""Lightweight schema helpers for micro-trend JSON validation and summarization."""
from __future__ import annotations
from typing import Any, Dict, List
# Keys that must all be present at the top level of a payload; any that are
# absent are reported by validate_trend_payload.
REQUIRED_TOP_LEVEL_KEYS = {"meta", "global_scene", "garments", "image_level_micro_trends"}
class ValidationError(Exception):
    """Raised when a micro-trend payload fails structural validation."""
def validate_trend_payload(payload: Any) -> Dict[str, Any]:
    """Check the coarse structure of a decoded micro-trend payload.

    The payload must be a dict containing every key in
    ``REQUIRED_TOP_LEVEL_KEYS``; ``garments`` must be a list of dicts that
    each carry a ``category`` key. When a garment has ``print_overview`` it
    must be a dict, and when it has ``print_placement`` it must be a list.

    Returns:
        The same ``payload`` object, unmodified.

    Raises:
        ValidationError: on the first structural problem found.
    """
    if not isinstance(payload, dict):
        raise ValidationError("Payload is not a JSON object")

    absent = REQUIRED_TOP_LEVEL_KEYS.difference(payload)
    if absent:
        raise ValidationError(f"Missing top-level keys: {', '.join(sorted(absent))}")

    # The key is known to exist at this point, so plain indexing is safe.
    garment_entries = payload["garments"]
    if not isinstance(garment_entries, list):
        raise ValidationError("`garments` must be a list")

    # (field name, required type, wording used in the error message);
    # these fields are only checked when the garment actually has them.
    optional_shapes = (
        ("print_overview", dict, "an object"),
        ("print_placement", list, "a list"),
    )
    for index, entry in enumerate(garment_entries):
        if not isinstance(entry, dict):
            raise ValidationError(f"garments[{index}] is not an object")
        if "category" not in entry:
            raise ValidationError(f"garments[{index}] missing `category`")
        for field, expected_type, wording in optional_shapes:
            if field in entry and not isinstance(entry[field], expected_type):
                raise ValidationError(f"garments[{index}].{field} must be {wording}")

    return payload
def _fmt_list(vals: List[Any]) -> str:
    """Join the truthy entries of ``vals`` into an English-style list.

    Falsy entries (``None``, ``""``, ``0`` ...) are dropped. The remaining
    entries are rendered as ``"a"``, ``"a and b"`` or ``"a, b and c"``.

    Args:
        vals: Items to join. Items are converted with ``str`` so numbers or
            other non-string values cannot break ``str.join``. A bare string
            is treated as a single item rather than iterated character by
            character, and ``None`` is treated as an empty list.

    Returns:
        The joined text, or ``""`` when nothing truthy remains.
    """
    if isinstance(vals, str):
        vals = [vals]
    items = [str(v) for v in (vals or ()) if v]
    if not items:
        return ""
    if len(items) == 1:
        return items[0]
    return ", ".join(items[:-1]) + f" and {items[-1]}"
def _summarize_placement(placements: List[Dict[str, Any]], max_items: int = 3) -> str:
    """Describe up to ``max_items`` print placements on one line.

    Each dict entry is rendered as ``zone (side[, ~N% coverage][, x orientation])``
    followed by ``[notes]`` when notes are present; entries are joined with
    ``"; "``.

    Args:
        placements: Placement entries. Dict entries are read through the keys
            ``zone``, ``side``, ``coverage_percent_of_zone``, ``orientation``
            and ``notes``. Non-dict entries are shown via ``str`` when truthy
            and skipped otherwise, instead of raising ``AttributeError``.
        max_items: Number of leading entries to describe (non-negative).
            When more entries exist, a trailing "additional placements not
            shown" marker is appended.

    Returns:
        The summary text, or ``"placement not specified"`` when there is
        nothing to describe.
    """
    if not placements:
        return "placement not specified"

    parts: List[str] = []
    for entry in placements[:max_items]:
        if not isinstance(entry, dict):
            if entry:
                parts.append(str(entry))
            continue

        zone = entry.get("zone") or "zone unknown"
        details = [str(entry.get("side") or "side n/a")]

        coverage = entry.get("coverage_percent_of_zone")
        # `is not None` so that an explicit 0% coverage is still reported.
        if coverage is not None:
            details.append(f"~{coverage}% coverage")

        orientation = entry.get("orientation")
        if orientation:
            # str() first: the value may be a number (e.g. degrees) rather than text.
            details.append(f"{str(orientation).lower()} orientation")

        chunk = f"{zone} ({', '.join(details)})"
        note = entry.get("notes")
        if note:
            chunk += f" [{note}]"
        parts.append(chunk)

    if len(placements) > max_items:
        parts.append("additional placements not shown")
    # Every entry may have been skipped; keep the placeholder in that case.
    return "; ".join(parts) or "placement not specified"
def _summarize_motifs(motifs: List[Dict[str, Any]], max_items: int = 3) -> str:
    """Describe up to ``max_items`` motif entries on one line.

    Each dict entry is rendered as
    ``motif_type (motif_description) | <scale, density and spacing> | colors: ...``
    with every optional part omitted when its value is missing or falsy;
    entries are joined with ``"; "``.

    Args:
        motifs: Motif entries. Dict entries are read through the keys
            ``motif_type``, ``motif_description``, ``scale``, ``density``,
            ``spacing_pattern`` and ``colorways``. Non-dict entries are shown
            via ``str`` when truthy and skipped otherwise, instead of raising
            ``AttributeError``.
        max_items: Number of leading entries to describe (non-negative).
            When more entries exist, a trailing "additional motif atoms not
            shown" marker is appended.

    Returns:
        The summary text, or ``"motifs not specified"`` when there is nothing
        to describe.
    """
    if not motifs:
        return "motifs not specified"

    parts: List[str] = []
    for entry in motifs[:max_items]:
        if not isinstance(entry, dict):
            if entry:
                parts.append(str(entry))
            continue

        chunk = str(entry.get("motif_type") or "motif")
        desc = entry.get("motif_description")
        if desc:
            chunk += f" ({desc})"

        # Stringify up front so a numeric scale/density cannot break the join.
        traits = [
            str(entry[key])
            for key in ("scale", "density", "spacing_pattern")
            if entry.get(key)
        ]
        if traits:
            chunk += f" | {_fmt_list(traits)}"

        colors = entry.get("colorways")
        if isinstance(colors, (list, tuple)):
            # Render the items, not the Python repr of the container.
            colors = ", ".join(str(c) for c in colors if c)
        if colors:
            chunk += f" | colors: {colors}"

        parts.append(chunk)

    if len(motifs) > max_items:
        parts.append("additional motif atoms not shown")
    # Every entry may have been skipped; keep the placeholder in that case.
    return "; ".join(parts) or "motifs not specified"
# Keys of a garment's `micro_trend_inferences` object whose tag lists are
# merged, in this order, into the garment bullet.
_GARMENT_TAG_KEYS = (
    "print_micro_trend_tags",
    "placement_micro_trend_tags",
    "color_micro_trend_tags",
    "other_detail_micro_trend_tags",
)


def _as_dict(value: Any) -> Dict[str, Any]:
    """Return ``value`` when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _str_items(value: Any) -> List[str]:
    """Return the truthy items of a list/tuple (or a lone string) as strings."""
    if isinstance(value, str):
        value = [value]
    elif not isinstance(value, (list, tuple)):
        return []
    return [str(item) for item in value if item]


def _scene_bullet(meta: Dict[str, Any], scene: Dict[str, Any]) -> str:
    """Build the leading ``**Scene:**`` bullet from the meta and scene objects."""
    quality = meta.get("image_quality")
    garment_count = meta.get("num_visible_garments")
    occlusions = scene.get("occlusions_or_crops")

    meta_bits = _fmt_list(
        _str_items(
            [
                f"image quality {quality}" if quality else "",
                meta.get("image_type"),
                meta.get("view_type"),
                # `is not None` so that a count of 0 is still reported.
                f"{garment_count} garment(s)" if garment_count is not None else "",
            ]
        )
    )
    scene_bits = _fmt_list(
        _str_items(
            [
                scene.get("setting"),
                "model present" if scene.get("model_present") else "",
                f"occlusions: {occlusions}" if occlusions else "",
            ]
        )
    )
    return f"**Scene:** {meta_bits or 'n/a'}; {scene_bits or 'setting n/a'}."


def _garment_bullet(idx: int, g: Dict[str, Any]) -> str:
    """Build the ``**Garment N ...**`` bullet for one garment object."""
    overview = _as_dict(g.get("print_overview"))
    color_story = _as_dict(g.get("color_story"))
    text_logo = _as_dict(g.get("text_and_logo_details"))
    inferences = _as_dict(g.get("micro_trend_inferences"))
    confidence = _as_dict(g.get("confidence"))

    cat = g.get("category") or g.get("sub_category") or "garment"
    role = g.get("role") or "primary"
    base_color = g.get("base_color_main") or "color n/a"

    bullet = f"**Garment {idx} ({role}) — {cat}:** base color {base_color}"
    secondary = _fmt_list(_str_items(g.get("base_color_secondary")))
    if secondary:
        bullet += f" with {secondary}"
    fabric = g.get("base_fabric_impression")
    if fabric:
        # f-string instead of `+` so a non-string value cannot raise TypeError.
        bullet += f" | fabric {fabric}"
    bullet += f"; print presence {g.get('print_presence') or 'n/a'}"

    primary_family = overview.get("primary_print_family")
    if primary_family:
        bullet += f"; primary print family {primary_family}"
    secondary_families = _fmt_list(_str_items(overview.get("secondary_print_families")))
    if secondary_families:
        bullet += f"; secondary {secondary_families}"
    style_tags = _fmt_list(_str_items(overview.get("print_style_tags")))
    if style_tags:
        bullet += f"; style {style_tags}"
    technique = overview.get("print_technique_estimate")
    if technique:
        bullet += f"; technique {technique}"

    placements = g.get("print_placement")
    motifs = g.get("motif_atoms")
    bullet += f"; placement: {_summarize_placement(placements if isinstance(placements, list) else [])}"
    bullet += f"; motifs: {_summarize_motifs(motifs if isinstance(motifs, list) else [])}"

    contrast = color_story.get("contrast_behavior")
    print_colors = _fmt_list(_str_items(color_story.get("print_colors")))
    if print_colors or contrast:
        bullet += (
            f"; colors: ground={color_story.get('ground_color') or 'n/a'}"
            f", print={print_colors or 'n/a'}, contrast={contrast or 'n/a'}"
        )

    if text_logo.get("has_text_or_logo"):
        logo_placements = _fmt_list(_str_items(text_logo.get("placement")))
        logo_style = text_logo.get("style")
        text_samples = _fmt_list(_str_items(text_logo.get("text_samples")))
        bullet += (
            f"; text/logo present ({logo_placements or 'placement n/a'}"
            f", style {logo_style or 'n/a'}, samples: {text_samples or 'n/a'})"
        )

    trend_tags = _fmt_list(
        [tag for key in _GARMENT_TAG_KEYS for tag in _str_items(inferences.get(key))]
    )
    if trend_tags:
        bullet += f"; micro-trend tags: {trend_tags}"

    overall = confidence.get("overall")
    # A numeric confidence of 0 is falsy but still worth reporting.
    is_number = isinstance(overall, (int, float)) and not isinstance(overall, bool)
    if overall or is_number:
        bullet += f"; confidence overall {overall}"

    return bullet + "."


def _image_level_bullets(trends: Dict[str, Any]) -> List[str]:
    """Build the optional image-level tag and summary bullets."""
    out: List[str] = []

    tags = trends.get("deduplicated_tags")
    labels = [str(t) for t in tags if t is not None] if isinstance(tags, list) else []
    if labels:
        out.append("**Image-level micro-trend tags:** " + ", ".join(labels) + ".")

    summary_comment = trends.get("summary_comment")
    if isinstance(summary_comment, str) and summary_comment.strip():
        out.append("**Image-level summary:** " + summary_comment.strip())
    return out


def build_summary(payload: Dict[str, Any], max_garments: int = 3) -> List[str]:
    """Derive Markdown-friendly bullet points that narrate the JSON contents.

    The result always starts with a ``**Scene:**`` bullet, followed by one
    bullet per garment and, when present, the image-level tag list and
    summary comment.

    Args:
        payload: Decoded micro-trend payload. Sections that are missing,
            ``None`` or not of the expected container type are treated as
            empty rather than raising.
        max_garments: Only the first ``max_garments`` entries of
            ``payload["garments"]`` are described.

    Returns:
        The bullet strings, in display order.
    """
    bullets: List[str] = [
        _scene_bullet(_as_dict(payload.get("meta")), _as_dict(payload.get("global_scene")))
    ]

    garments = payload.get("garments")
    if not isinstance(garments, list):
        # Covers both a missing key and an explicit null.
        garments = []
    for idx, garment in enumerate(garments[:max_garments], start=1):
        if not isinstance(garment, dict):
            # Best-effort summary: skip entries that are not garment objects.
            continue
        bullets.append(_garment_bullet(idx, garment))

    bullets.extend(_image_level_bullets(_as_dict(payload.get("image_level_micro_trends"))))
    return bullets