# image-understanding / sample_code / generate_images.py
# Provenance: Hugging Face commit 2948ced by shahkushan1 —
# "Add Gradio micro-trend app with LLM integrations and prompt loading"
from __future__ import annotations

import json
import logging
import mimetypes
import os
import re
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, TypedDict

from google import genai
from google.genai import types, errors as genai_errors

from constants import (
    ROOT,
    PLAN_PATH,
    DEFAULT_SETTINGS,
    GEMINI_SETTINGS_KEYS,
    LOG_NAME,
    STYLE_VIEW_ORDER,
)
class PromptTask(TypedDict):
    """A single prompt item parsed from plan.md.

    Previously declared as ``class PromptTask(Dict[str, Any])`` with class-level
    annotations, which gave neither key checking nor any runtime effect.
    TypedDict expresses the intended contract; tasks remain plain dicts at
    runtime, so existing annotation-only and keyword-construction uses keep
    working.

    Keys:
        slide: Slide header text the prompt was found under.
        filename: Output image filename from the FILENAME: line.
        prompt: Full prompt text (PROMPT line plus continuation lines).
        order: Zero-based position within the parsed plan.
    """

    slide: str
    filename: str
    prompt: str
    order: int
def slugify(text: str) -> str:
    """Return a filesystem-friendly slug for a slide label.

    Lowercases, collapses every non-alphanumeric run into a single hyphen,
    trims edge hyphens, and falls back to "slide" when nothing remains.
    """
    collapsed = re.sub(r"[^a-z0-9]+", "-", text.lower()).strip("-")
    return collapsed if collapsed else "slide"
def output_root(brand: str, collection: str) -> Path:
    """Base directory for images: outputs/<brand>/collection/<collection>/images."""
    brand_slug = slugify(brand)
    collection_slug = slugify(collection)
    return ROOT / "outputs" / brand_slug / "collection" / collection_slug / "images"
def parse_plan(plan_path: Path) -> List[PromptTask]:
    """Pull every FILENAME/PROMPT pair from plan.md, keeping slide context.

    Args:
        plan_path: Path to the plan markdown file (read as UTF-8).

    Returns:
        PromptTask dicts in source order: slide header text, output filename,
        full prompt (PROMPT line plus continuation lines), and running order.

    Raises:
        ValueError: If a FILENAME line has no following PROMPT line.
    """
    lines = plan_path.read_text(encoding="utf-8").splitlines()
    tasks: List[PromptTask] = []
    current_slide = "slide"
    order = 0
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        # Capture slide headers (e.g., "Slide 6", "Slide 6A", "Slides 8–19")
        if re.match(r"slide[s]?\s+([\w–-]+)", line, re.IGNORECASE):
            current_slide = line
            i += 1
            continue
        file_match = re.match(r"FILENAME:\s*(.+)", line, re.IGNORECASE)
        if not file_match:
            i += 1
            continue
        filename = file_match.group(1).strip()
        # Advance to the PROMPT line for this filename.
        j = i + 1
        while j < len(lines) and not lines[j].strip().lower().startswith("prompt:"):
            j += 1
        if j >= len(lines):
            # Fix: report the offending filename; the old message was the
            # literal f-string "PROMPT missing for (unknown)".
            raise ValueError(f"PROMPT missing for {filename}")
        prompt_line = lines[j].strip()
        # Fix: slice the marker off case-insensitively. The search above
        # matches any case of "prompt:", but the old
        # split("PROMPT:", 1)[1] raised IndexError on a lowercase line.
        prompt = prompt_line[len("PROMPT:"):].strip()
        # Capture any prompt continuation lines until the next FILENAME/Slide header
        k = j + 1
        continuation: List[str] = []
        while k < len(lines):
            next_line = lines[k].strip()
            if next_line == "":
                k += 1
                continue
            if re.match(r"(FILENAME:|Slide[s]?\s+|< Text Content)", next_line, re.IGNORECASE):
                break
            continuation.append(next_line)
            k += 1
        if continuation:
            prompt = " ".join([prompt] + continuation)
        tasks.append(
            {
                "slide": current_slide,
                "filename": filename,
                "prompt": prompt,
                "order": order,
            }
        )
        order += 1
        # Resume scanning at the line that ended the continuation block.
        i = k
    return tasks
def setup_logging(out_root: Path, mode: str, level: str = "INFO", log_path: Path | None = None) -> logging.Logger:
    """Configure stdout + file logging for a run.

    The log file defaults to
    outputs/<brand>/collection/<collection>/images/<mode>/run.log unless an
    explicit log_path is supplied. Reconfigures the root logger (force=True)
    and returns the named project logger.
    """
    mode_dir = out_root / mode
    mode_dir.mkdir(parents=True, exist_ok=True)
    target_file = log_path if log_path else mode_dir / "run.log"
    resolved_level = getattr(logging, level.upper(), logging.INFO)
    fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
    stream_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(target_file, encoding="utf-8")
    for handler in (stream_handler, file_handler):
        handler.setFormatter(fmt)
    logging.basicConfig(level=resolved_level, handlers=[stream_handler, file_handler], force=True)
    logger = logging.getLogger(LOG_NAME)
    logger.setLevel(resolved_level)
    logger.info("Logging initialized (mode=%s, file=%s, level=%s)", mode, target_file, level.upper())
    return logger
def clean_output_dir(out_root: Path, mode: str, logger: logging.Logger | None = None) -> None:
    """Remove the entire <mode> folder so the run starts from a clean slate.

    No-op when the folder does not exist; logs before deleting when a logger
    is provided.
    """
    mode_dir = out_root / mode
    if not mode_dir.exists():
        return
    if logger:
        logger.info("Cleaning output directory %s", mode_dir)
    shutil.rmtree(mode_dir)
def anchor_part(prompt: str, logger: logging.Logger) -> tuple[None, None]:
    """Deprecated no-op: folder-based anchor images were removed.

    Retained only so existing call sites keep working; both arguments are
    ignored and the result is always (None, None).
    """
    return (None, None)
def part_from_path(path: Path) -> types.Part:
    """Load an image file as a genai Part with an inferred MIME type.

    Falls back to "image/jpeg" when the extension yields no guess.
    """
    guessed, _ = mimetypes.guess_type(path)
    mime = guessed or "image/jpeg"
    return types.Part.from_bytes(data=path.read_bytes(), mime_type=mime)
def detect_style_view(filename: str) -> tuple[str, str] | None:
    """Return (style_code, view) for a style-view filename, else None.

    Matches names like "MG-X-SS26-001_hero.png" (case-insensitive); the view
    part is normalized to lowercase, the code is returned as written.
    """
    pattern = r"^(MG-[A-Z]-SS\d{2}-\d{3})_(hero|front|back)\."
    match = re.match(pattern, filename, re.IGNORECASE)
    if match is None:
        return None
    return match.group(1), match.group(2).lower()
def reorder_tasks_for_styles(tasks: List[PromptTask]) -> List[PromptTask]:
    """Group style views per style code, ordered by STYLE_VIEW_ORDER.

    Non-style tasks keep their original positions; a style's whole group is
    emitted at the first occurrence of its code.
    """
    # Pass 1: bucket copies of style-view tasks by code, tagging each copy
    # with its view so the bucket can be sorted afterwards.
    buckets: dict[str, list[PromptTask]] = {}
    for task in tasks:
        detected = detect_style_view(task["filename"])
        if detected:
            code, view = detected
            buckets.setdefault(code, []).append(task | {"_style_view": view})
    # Pass 2: rebuild the list, expanding each style code once, in view order.
    ordered: list[PromptTask] = []
    seen_codes: set[str] = set()
    for task in tasks:
        detected = detect_style_view(task["filename"])
        if not detected:
            ordered.append(task)
            continue
        code = detected[0]
        if code in seen_codes:
            continue
        seen_codes.add(code)
        bucket = buckets.get(code, [])
        bucket.sort(key=lambda entry: (STYLE_VIEW_ORDER.get(entry.get("_style_view", "other"), 99), entry["order"]))
        for entry in bucket:
            entry.pop("_style_view", None)  # strip helper key before returning
            ordered.append(entry)
    return ordered
def load_settings(settings_path: Path | None) -> dict[str, str]:
    """Load settings JSON (if present) limited to known keys.

    Args:
        settings_path: Explicit settings file, or None to use DEFAULT_SETTINGS.

    Returns:
        Mapping restricted to GEMINI_SETTINGS_KEYS with truthy values; empty
        dict when the file does not exist.

    Raises:
        SystemExit: If the file exists but is not valid JSON.
    """
    path = settings_path or DEFAULT_SETTINGS
    if not path.exists():
        return {}
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        # Fix: chain the cause so the JSON position info survives; the old
        # bare raise discarded it (and carried a misleading BLE001 noqa on a
        # narrow except).
        raise SystemExit(f"settings file {path} is not valid JSON: {exc}") from exc
    return {k: v for k, v in data.items() if k in GEMINI_SETTINGS_KEYS and v}
def resolve_api_key(settings: dict[str, str]) -> str:
    """Get API key from env first, then settings file; env wins.

    Checks GEMINI_API_KEY then GOOGLE_API_KEY in the environment, then the
    same keys in the settings mapping. Exits with guidance when none is set.
    """
    for env_name in ("GEMINI_API_KEY", "GOOGLE_API_KEY"):
        value = os.environ.get(env_name)
        if value:
            return value
    from_settings = settings.get("GEMINI_API_KEY") or settings.get("GOOGLE_API_KEY")
    if from_settings:
        return from_settings
    raise SystemExit(
        "GEMINI_API_KEY/GOOGLE_API_KEY is not set. Set the env var or create settings.json (see settings.example.json)."
    )
def generate_images(
    tasks: List[PromptTask],
    mode: str,
    limit: int | None,
    api_key: str,
    logger: logging.Logger,
    out_root: Path,
    timestamp: str,
) -> None:
    """Generate images for the provided tasks list and write a manifest.

    Args:
        tasks: Prompt tasks (already reordered for style chaining).
        mode: "full" runs every task; any other mode runs a sample slice.
        limit: Max tasks in sample mode; a falsy limit defaults to 2.
        api_key: Google AI Studio API key for the genai client.
        logger: Logger for progress and diagnostics.
        out_root: Base images directory (outputs/.../images).
        timestamp: Run timestamp used in the manifest filename.

    Side effects:
        Writes image files under out_root/<mode>/<slide-slug>/ and a
        manifest_<timestamp>.json next to them; failures per task are logged
        and skipped rather than aborting the run.
    """
    client = genai.Client(api_key=api_key)
    # Sample mode trims the list; "limit or 2" makes a falsy limit mean 2.
    to_run = tasks if mode == "full" else tasks[: limit or 2]
    logger.info("Starting generation: %s tasks (mode=%s)", len(to_run), mode)
    # Per-style map of view name -> saved image path, used to anchor later views.
    style_state: dict[str, dict[str, Path]] = {}
    manifest = []
    for task in to_run:
        slide_slug = slugify(task["slide"])
        out_dir = out_root / mode / slide_slug
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / task["filename"]
        logger.info("Generating %s (slide: %s)", task["filename"], task["slide"])
        style_view = detect_style_view(task["filename"])
        # anchor_part is a retained no-op: anchor/anchor_code are always None,
        # so the "if not contents and anchor" branch below never fires.
        anchor, anchor_code = anchor_part(task["prompt"], logger)
        anchor_used = None
        contents: list[types.Part | str] = []
        if style_view:
            style_code, view = style_view
            state = style_state.get(style_code, {})
            preferred_path: Path | None = None
            if view == "hero":
                preferred_path = None  # first in chain, prompt-only
            elif view == "front":
                preferred_path = state.get("hero")
                anchor_used = "hero" if preferred_path else None
            elif view == "back":
                # Prefer the front view as anchor; fall back to hero.
                preferred_path = state.get("front") or state.get("hero")
                anchor_used = "front" if state.get("front") else ("hero" if state.get("hero") else None)
            if preferred_path and preferred_path.exists():
                try:
                    contents.append(part_from_path(preferred_path))
                    anchor_used = anchor_used or "previous"
                except Exception as exc:  # noqa: BLE001
                    # Best-effort anchoring: fall through to prompt-only generation.
                    logger.exception("Failed to load prior view %s as anchor: %s", preferred_path, exc)
        if not contents and anchor:
            contents.append(anchor)
            anchor_used = anchor_used or (f"face:{anchor_code}" if anchor_code else "face")
        contents.append(task["prompt"])
        try:
            response = client.models.generate_content(
                model="gemini-2.5-flash-image",
                contents=contents,
                config=types.GenerateContentConfig(
                    response_modalities=["image"],
                ),
            )
        except genai_errors.ClientError as exc:  # noqa: BLE001
            # 401 gets a dedicated hint before the generic failure log.
            if exc.status_code == 401:
                logger.error(
                    "401 Unauthorized. This usually means the key is missing, the wrong key type (use Google AI Studio key), or Vertex mode requires OAuth."
                )
            logger.exception("Generation failed for %s: %s", task["filename"], exc)
            continue
        except Exception as exc:  # noqa: BLE001
            logger.exception("Generation failed for %s: %s", task["filename"], exc)
            continue
        parts = getattr(response, "parts", None)
        if not parts:
            logger.warning("Response had no parts for %s; skipping", task["filename"])
            continue
        # First response part carrying inline image bytes.
        image_part = next((p for p in parts if getattr(p, "inline_data", None)), None)
        if not image_part:
            logger.warning("No image part returned for %s; skipping", task["filename"])
            continue
        try:
            image = image_part.as_image()
            image.save(out_path)
            logger.info("Saved %s", out_path)
        except Exception as exc:  # noqa: BLE001
            logger.exception("Failed to save %s: %s", out_path, exc)
            continue
        if style_view:
            # Record the saved view so later views of this style can anchor on it.
            style_code, view = style_view
            style_state.setdefault(style_code, {})[view] = out_path
        manifest.append(
            {
                "slide": task["slide"],
                "filename": task["filename"],
                "prompt": task["prompt"],
                "path": str(out_path.relative_to(ROOT)),
                "anchor": anchor_used,
                "anchor_face": anchor_code,
            }
        )
    if manifest:
        manifest_path = out_root / mode / f"manifest_{timestamp}.json"
        manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
        logger.info("Manifest written to %s", manifest_path)
    else:
        logger.warning("No images were generated; manifest not written")
def run_generation(
    mode: str = "full",
    limit: int | None = None,
    settings_path: Path | None = None,
    brand: str = "mango",
    collection: str = "hot-summer-ss26",
    log_level: str = "INFO",
    clean: bool = False,
) -> None:
    """Programmatic entrypoint: parse plan.md, save prompts, generate images.

    Validates inputs, optionally cleans the output folder, configures logging,
    resolves settings and the API key, persists the resolved prompt list for
    traceability, then delegates to generate_images.
    """
    parsed = reorder_tasks_for_styles(parse_plan(PLAN_PATH))
    if not parsed:
        raise SystemExit("No prompts found in plan.md")
    if mode == "sample" and limit is not None and limit <= 0:
        raise SystemExit("limit must be positive for sample mode")
    out_root = output_root(brand, collection)
    # Clean before logging is configured so the fresh run.log is not deleted.
    if clean:
        clean_output_dir(out_root, mode)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    logger = setup_logging(out_root, mode, log_level)
    settings = load_settings(settings_path)
    vertex_flag = "GOOGLE_GENAI_USE_VERTEXAI"
    # Settings may opt into Vertex mode; an existing env var takes precedence.
    if vertex_flag in settings and vertex_flag not in os.environ:
        os.environ[vertex_flag] = str(settings[vertex_flag]).lower()
    api_key = resolve_api_key(settings)
    prompts_dir = ROOT / "outputs" / slugify(brand) / "collection" / slugify(collection) / "prompts"
    prompts_dir.mkdir(parents=True, exist_ok=True)
    prompts_path = prompts_dir / f"images_prompts_{timestamp}.json"
    payload = [
        {"slide": t["slide"], "filename": t["filename"], "prompt": t["prompt"], "order": t["order"]}
        for t in parsed
    ]
    prompts_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    logger.info("Prompts saved to %s", prompts_path)
    generate_images(parsed, mode, limit, api_key, logger, out_root, timestamp)
# ---------------- Reusable runner for external callers ---------------- #
def run_prompt_list(
    prompt_items: List[Dict[str, Any]],
    brand: str,
    collection: str,
    mode: str,
    api_key: str | None,
    logger: logging.Logger,
) -> List[Dict[str, Any]]:
    """
    Run a list of prompts (each dict: prompt, filename) through Gemini and save to images/<mode>.
    Returns manifest entries.
    Includes simple anchoring for style views (hero -> front -> back) using previously
    generated images for the same style code.
    """
    out_root = output_root(brand, collection) / mode
    out_root.mkdir(parents=True, exist_ok=True)
    # Auth resolution: prefer explicit api_key, else settings.json (no env reliance)
    settings_path = ROOT / "settings.json"
    settings = {}
    if settings_path.exists():
        try:
            settings = json.loads(settings_path.read_text(encoding="utf-8"))
        except Exception:
            # Best-effort: an unreadable/invalid settings.json is treated as absent.
            settings = {}
    use_vertex = str(settings.get("GOOGLE_GENAI_USE_VERTEXAI", "")).lower() == "true"
    if not api_key:
        api_key = settings.get("GEMINI_API_KEY") or settings.get("GOOGLE_API_KEY")
    # Vertex project/location fall back through several conventional key names.
    project = (
        settings.get("GOOGLE_VERTEX_PROJECT")
        or settings.get("GOOGLE_CLOUD_PROJECT")
        or settings.get("GCLOUD_PROJECT")
    )
    location = settings.get("GOOGLE_VERTEX_LOCATION") or settings.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
    logger.info(
        "[gemini] auth resolution: api_key=%s use_vertex=%s project=%s location=%s",
        "yes" if api_key else "no",
        use_vertex,
        project or "none",
        location,
    )
    # Vertex mode takes precedence over a supplied API key when both are set.
    if use_vertex:
        if not project:
            raise SystemExit(
                "Gemini Vertex auth missing project. Set GOOGLE_VERTEX_PROJECT or GOOGLE_CLOUD_PROJECT in settings.json."
            )
        client = genai.Client(vertexai={"project": project, "location": location})
        logger.info("[gemini] using Vertex ADC project=%s location=%s", project, location)
    elif api_key:
        client = genai.Client(api_key=api_key)
        logger.info("[gemini] using API key auth")
    else:
        raise SystemExit(
            "Gemini auth missing: set GEMINI_API_KEY/GOOGLE_API_KEY in settings.json or set GOOGLE_GENAI_USE_VERTEXAI=true with GOOGLE_CLOUD_PROJECT in settings.json"
        )
    manifest = []
    # Per-style map of view -> saved path, used to anchor front/back generations.
    style_state: dict[str, dict[str, Path]] = {}
    for item in prompt_items:
        prompt = item["prompt"]
        filename = item.get("filename") or f"prompt_{len(manifest)+1}.png"
        slide_slug = slugify(item.get("slide", "adhoc"))
        if item.get("out_path"):
            # Explicit out_path wins; relative paths are resolved against ROOT.
            out_path = (ROOT / item["out_path"]).resolve() if not Path(item["out_path"]).is_absolute() else Path(item["out_path"])
            out_path.parent.mkdir(parents=True, exist_ok=True)
        else:
            out_dir = out_root / slide_slug
            out_dir.mkdir(parents=True, exist_ok=True)
            out_path = out_dir / filename
        logger.info("[run_prompt_list] %s -> %s", filename, out_path)
        contents: list[types.Part | str] = []
        # Style chaining: if filename matches style view, attach prior image
        anchor_used = None
        style_view = detect_style_view(filename)
        if style_view:
            code, view = style_view
            state = style_state.get(code, {})
            preferred_path: Path | None = None
            if view == "hero":
                preferred_path = None  # hero is first in the chain: prompt-only
            elif view == "front":
                preferred_path = state.get("hero")
                anchor_used = "hero" if preferred_path else None
            elif view == "back":
                # Prefer the front view as anchor; fall back to hero.
                preferred_path = state.get("front") or state.get("hero")
                anchor_used = "front" if state.get("front") else ("hero" if state.get("hero") else None)
            if preferred_path and preferred_path.exists():
                try:
                    contents.append(part_from_path(preferred_path))
                    anchor_used = anchor_used or "previous"
                except Exception as exc:  # noqa: BLE001
                    # Best-effort anchoring: fall through to prompt-only generation.
                    logger.exception("Failed to load prior view %s as anchor: %s", preferred_path, exc)
        contents.append(prompt)
        try:
            resp = client.models.generate_content(
                model="gemini-2.5-flash-image",
                contents=contents,
                config=types.GenerateContentConfig(response_modalities=["image"]),
            )
            # Prefer the flat .parts accessor; fall back to walking candidates
            # (SDK response shape varies — presumably by transport/version; verify).
            image_part = None
            if hasattr(resp, "parts") and resp.parts:
                image_part = next((p for p in resp.parts if getattr(p, "inline_data", None)), None)
            if not image_part and hasattr(resp, "candidates"):
                for cand in resp.candidates or []:
                    content = getattr(cand, "content", None)
                    parts = getattr(content, "parts", []) if content else []
                    for part in parts or []:
                        if getattr(part, "inline_data", None):
                            image_part = part
                            break
                    if image_part:
                        break
            if not image_part:
                logger.warning("[run_prompt_list] no image returned for %s", filename)
                manifest.append({"filename": filename, "status": "no_image"})
                continue
            image = image_part.as_image()
            image.save(out_path)
            manifest.append({"filename": filename, "status": "ok", "path": str(out_path.relative_to(ROOT)), "anchor": anchor_used})
            if style_view:
                # Record the saved view so later views of this style can anchor on it.
                code, view = style_view
                style_state.setdefault(code, {})[view] = out_path
        except Exception as exc:  # noqa: BLE001
            logger.exception("[run_prompt_list] failed for %s: %s", filename, exc)
            manifest.append({"filename": filename, "status": f"error:{exc}"})
    return manifest
def run_prompt_list_vertex_chain(
    prompt_items: List[Dict[str, Any]],
    brand: str,
    collection: str,
    mode: str,
    logger: logging.Logger,
    temp: float = 1.0,
    top_p: float = 0.95,
) -> List[Dict[str, Any]]:
    """
    Multi-turn Vertex image chain per style (hero → front → back) with image feedback.
    End-to-end flow:
    1) Group prompts by style/slide so each style runs as one mini-session.
    2) HERO: call Gemini Vertex with the hero prompt (no anchors). Save the returned image.
    3) FRONT: send the original hero prompt as a user turn, the hero image as a *model* turn,
       then the front prompt as a user turn. Generate and save the front image.
    4) BACK: send hero prompt + hero image (model turn) + front prompt + front image (model turn),
       then the back prompt. Generate and save the back image.
    5) Persist outputs under `outputs/<brand>/collection/<collection>/images/<mode>/...`
       and record a manifest entry per view.
    Notes:
    - Uses Vertex client with explicit safety + image config (1:1, 1K) and temperature/top_p controls.
    - If any of the three views fail, the function logs an error for that view and continues to the next style.
    """
    from collections import defaultdict

    out_root = output_root(brand, collection) / mode
    out_root.mkdir(parents=True, exist_ok=True)
    # Relies on ambient ADC credentials/project configuration for Vertex auth.
    client = genai.Client(vertexai=True)
    # One shared config for all three views of every style.
    cfg = types.GenerateContentConfig(
        temperature=temp,
        top_p=top_p,
        max_output_tokens=32768,
        response_modalities=["TEXT", "IMAGE"],
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
            types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
            types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
            types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
        ],
        image_config=types.ImageConfig(
            aspect_ratio="1:1",
            image_size="1K",
            output_mime_type="image/png",
        ),
    )

    def to_model_image_content(img_bytes: bytes) -> types.Content:
        """Wrap prior image bytes as a model-role content part for chaining."""
        return types.Content(
            role="model",
            parts=[
                # NOTE(review): a lone backtick text part precedes the image;
                # looks unintentional — confirm whether it can be dropped.
                types.Part.from_text(text="`"),
                types.Part.from_bytes(data=img_bytes, mime_type="image/png"),
            ],
        )

    def extract_first_image(resp) -> bytes | None:
        """Extract the first inline image payload from a Vertex response object."""
        for cand in getattr(resp, "candidates", []) or []:
            parts = getattr(getattr(cand, "content", None), "parts", []) or []
            for part in parts:
                if getattr(part, "inline_data", None) and getattr(part.inline_data, "data", None):
                    return part.inline_data.data
        return None

    # Bucket prompt items per style so each bucket is one hero/front/back session.
    grouped: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for itm in prompt_items:
        grouped[itm.get("slide") or itm.get("style_name") or "unknown"].append(itm)
    manifest: List[Dict[str, Any]] = []
    for slide, items in grouped.items():
        logger.info("[vertex-chain] style=%s items=%d", slide, len(items))
        # Views are identified by filename substring, not detect_style_view.
        hero = next((i for i in items if "_hero" in i.get("filename", "")), None)
        front = next((i for i in items if "_front" in i.get("filename", "")), None)
        back = next((i for i in items if "_back" in i.get("filename", "")), None)
        if not (hero and front and back):
            # Chaining needs all three views; incomplete styles are skipped.
            logger.warning("[vertex-chain] skip %s missing hero/front/back", slide)
            continue

        def resolve_out_path(itm: Dict[str, Any]) -> Path:
            """Resolve the output path for an item, creating parent folders as needed."""
            op = itm.get("out_path")
            if op:
                # Explicit out_path wins; relative paths resolve against ROOT.
                p = Path(op)
                if not p.is_absolute():
                    p = ROOT / p
                p.parent.mkdir(parents=True, exist_ok=True)
                return p
            out_dir = out_root / (itm.get("slide") or slide)
            out_dir.mkdir(parents=True, exist_ok=True)
            return out_dir / itm.get("filename", "out.png")

        # HERO
        # 1) Hero request: single user turn with hero prompt.
        hero_resp = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=[types.Content(role="user", parts=[types.Part.from_text(text=hero["prompt"])])],
            config=cfg,
        )
        hero_img = extract_first_image(hero_resp)
        if not hero_img:
            manifest.append({"filename": hero.get("filename"), "status": "error", "path": None})
            logger.error("[vertex-chain] no hero image for %s", slide)
            continue
        hero_path = resolve_out_path(hero)
        hero_path.write_bytes(hero_img)
        manifest.append({"filename": hero.get("filename"), "status": "ok", "path": str(hero_path.relative_to(ROOT))})
        # FRONT
        # 2) Front request: feed hero prompt (user) + hero image (model turn) + front prompt (user).
        contents_front = [
            types.Content(role="user", parts=[types.Part.from_text(text=hero["prompt"])]),
            to_model_image_content(hero_img),
            types.Content(role="user", parts=[types.Part.from_text(text=front["prompt"])]),
        ]
        front_resp = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=contents_front,
            config=cfg,
        )
        front_img = extract_first_image(front_resp)
        if not front_img:
            manifest.append({"filename": front.get("filename"), "status": "error", "path": None})
            logger.error("[vertex-chain] no front image for %s", slide)
            continue
        front_path = resolve_out_path(front)
        front_path.write_bytes(front_img)
        manifest.append({"filename": front.get("filename"), "status": "ok", "path": str(front_path.relative_to(ROOT))})
        # BACK
        # 3) Back request: hero prompt (user) + hero image (model) + front prompt (user) + front image (model) + back prompt (user).
        contents_back = [
            types.Content(role="user", parts=[types.Part.from_text(text=hero["prompt"])]),
            to_model_image_content(hero_img),
            types.Content(role="user", parts=[types.Part.from_text(text=front["prompt"])]),
            to_model_image_content(front_img),
            types.Content(role="user", parts=[types.Part.from_text(text=back["prompt"])]),
        ]
        back_resp = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=contents_back,
            config=cfg,
        )
        back_img = extract_first_image(back_resp)
        if not back_img:
            manifest.append({"filename": back.get("filename"), "status": "error", "path": None})
            logger.error("[vertex-chain] no back image for %s", slide)
            continue
        back_path = resolve_out_path(back)
        back_path.write_bytes(back_img)
        manifest.append({"filename": back.get("filename"), "status": "ok", "path": str(back_path.relative_to(ROOT))})
    return manifest