# ForgeCaptions — app.py
# (Removed Hugging Face file-viewer residue: author, commit hash, raw/history/blame links.)
import os, io, csv, time, json, base64, re
from typing import List, Tuple, Dict, Any
# ---------------------------------------------------------------------
# Caching
# ---------------------------------------------------------------------
# Point the Hugging Face cache at a writable location *before* any HF
# library reads HF_HOME (Spaces containers run as /home/user).
os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
import gradio as gr
from PIL import Image
import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor
# ── HF Spaces GPU decorator (no-op on CPU/local) ─────────────────────
# On Spaces, `spaces.GPU()` routes the decorated function into the GPU
# worker process; locally the import fails and `gpu` degrades to identity.
try:
    import spaces
    gpu = spaces.GPU()
except Exception:  # local/CPU
    def gpu(f): return f
APP_DIR = os.getcwd()
# Session/settings/journal live under /tmp (ephemeral between Space restarts).
SESSION_FILE = "/tmp/session.json"       # captioned rows of the current session
SETTINGS_FILE = "/tmp/cf_settings.json"  # persisted UI + generation settings
JOURNAL_FILE = "/tmp/cf_journal.json"    # auxiliary journal data
THUMB_CACHE = os.path.expanduser("~/.cache/forgecaptions/thumbs")  # gallery thumbnails
EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"  # resized images embedded in .xlsx exports
os.makedirs(THUMB_CACHE, exist_ok=True)
os.makedirs(EXCEL_THUMB_DIR, exist_ok=True)
# ---------------------------------------------------------------------
# Model identifiers
# ---------------------------------------------------------------------
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
# Load the processor on CPU (safe in stateless env)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# Lazy GPU/CPU model (created inside GPU worker only)
_MODEL = None           # populated by get_model() on first call
_DEVICE = "cpu"         # switched to "cuda" when a GPU is visible
_DTYPE = torch.float32  # bf16 on GPU, fp32 on CPU
def get_model():
    """Create/reuse model; only call this from inside @gpu functions."""
    global _MODEL, _DEVICE, _DTYPE
    # Already initialized — reuse the cached instance.
    if _MODEL is not None:
        return _MODEL, _DEVICE, _DTYPE
    use_cuda = torch.cuda.is_available()
    _DEVICE = "cuda" if use_cuda else "cpu"
    _DTYPE = torch.bfloat16 if use_cuda else torch.float32
    _MODEL = LlavaForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=_DTYPE,
        low_cpu_mem_usage=True,
        # GPU:0 inside the GPU worker process; explicit "cpu" otherwise.
        device_map=0 if use_cuda else "cpu",
    )
    _MODEL.eval()
    print(f"[ForgeCaptions] Model ready on {_DEVICE} dtype={_DTYPE}")
    return _MODEL, _DEVICE, _DTYPE
print(f"[ForgeCaptions] Gradio version: {gr.__version__}")
# ---------------------------------------------------------------------
# Instruction templates & options
# ---------------------------------------------------------------------
# Labels offered in the "Caption style" CheckboxGroup. Each label should
# have a matching key in CAPTION_TYPE_MAP (final_instruction does a .get).
STYLE_OPTIONS = [
    "Descriptive (short)", "Descriptive (long)",
    "Character training (short)", "Character training (long)",
    "LoRA (Flux_D Realism) (short)", "LoRA (Flux_D Realism) (long)",
    "E-commerce product (short)", "E-commerce product (long)",
    "Portrait (photography) (short)", "Portrait (photography) (long)",
    "Landscape (photography) (short)", "Landscape (photography) (long)",
    "Art analysis (no artist names) (short)", "Art analysis (no artist names) (long)",
    "Social caption (short)", "Social caption (long)",
    "Aesthetic tags (comma-sep)"
]
# Maps each UI style label to the instruction text sent to the model.
# Keys must match STYLE_OPTIONS exactly: final_instruction() looks styles up
# with .get(style, ""), so a mismatched key silently yields an empty prompt.
CAPTION_TYPE_MAP = {
    "Descriptive (short)": "One sentence (≤25 words) describing the most important visible elements only. No speculation.",
    "Descriptive (long)": "Write a detailed description for this image.",
    "Character training (short)": (
        "Output a concise, prompt-like caption for character LoRA/ID training. "
        "Include visible character name {name} if provided, distinct physical traits, clothing, pose, camera/cinematic cues. "
        "No backstory; no non-visible traits. Prefer comma-separated phrases."
    ),
    "Character training (long)": (
        "Write a thorough, training-ready caption for a character dataset. "
        "Use {name} if provided; describe only what is visible: physique, face/hair, clothing, accessories, actions, pose, "
        "camera angle/focal cues, lighting, background context. 1–3 sentences; no backstory or meta."
    ),
    # FIX: these two were keyed "Flux_D (short)"/"Flux_D (long)", which never
    # appear in STYLE_OPTIONS, so selecting the "LoRA (Flux_D Realism)" styles
    # contributed nothing to the instruction. Renamed to the STYLE_OPTIONS
    # labels (matching the legacy migration map in load_settings).
    "LoRA (Flux_D Realism) (short)": "Output a short Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
    "LoRA (Flux_D Realism) (long)": "Output a long Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
    "Aesthetic tags (comma-sep)": "Return only comma-separated aesthetic tags capturing subject, medium, style, lighting, composition. No sentences.",
    "E-commerce product (short)": "One sentence highlighting key attributes, material, color, use case. No fluff.",
    "E-commerce product (long)": "Write a crisp product description highlighting key attributes, materials, color, usage, and distinguishing traits.",
    "Portrait (photography) (short)": "One sentence portrait description: subject, pose/expression, camera angle, lighting, background.",
    "Portrait (photography) (long)": "Describe a portrait: subject, age range, pose, facial expression, camera angle, focal length cues, lighting, background.",
    "Landscape (photography) (short)": "One sentence landscape description: major elements, time of day, weather, vantage point, mood.",
    "Landscape (photography) (long)": "Describe landscape elements, time of day, weather, vantage point, composition, and mood.",
    "Art analysis (no artist names) (short)": "One sentence describing medium, style, composition, palette; do not mention artist/title.",
    "Art analysis (no artist names) (long)": "Analyze the artwork's visible elements, medium, style, composition, palette. Do not mention artist names or titles.",
    "Social caption (short)": "Write a short, catchy caption (max 25 words) describing the visible content. No hashtags.",
    "Social caption (long)": "Write a slightly longer, engaging caption (≤50 words) describing the visible content. No hashtags."
}
# Optional instruction fragments appended verbatim after the style templates.
EXTRA_CHOICES = [
    "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
    "Do NOT include information about whether there is a watermark or not.",
    "Do NOT use any ambiguous language.",
    "ONLY describe the most important elements of the image.",
    "Include information about the ages of any people/characters when applicable.",
    "Explicitly specify the vantage height (eye-level, low-angle worm’s-eye, bird’s-eye, drone, rooftop, etc.).",
    "Focus captions only on clothing/fashion details.",
    "Focus on setting, scenery, and context; ignore subject details.",
    "ONLY describe the subject’s pose, movement, or action. Do NOT mention appearance, clothing, or setting.",
    "Do NOT include anything sexual; keep it PG.",
    "Include synonyms/alternate phrasing to diversify training set.",
    "ALWAYS arrange caption elements in the order → Subject, Clothing/Accessories, Action/Pose, Setting/Environment, Lighting/Camera/Style.",
    "Do NOT mention the image's resolution.",
    "Include information about depth, lighting, and camera angle.",
    "Include information on composition (rule of thirds, symmetry, leading lines, etc).",
    "Specify the depth of field and whether the background is in focus or blurred.",
    "If applicable, mention the likely use of artificial or natural lighting sources.",
    "Identify the image orientation (portrait, landscape, or square) if obvious.",
]
# Special extra: when selected, final_instruction substitutes the user-supplied
# name for every "{name}" placeholder in the assembled instruction.
NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def ensure_thumb(path: str, max_side=256) -> str:
    """Return a cached JPEG thumbnail for *path*; fall back to *path* on any failure."""
    try:
        img = Image.open(path).convert("RGB")
    except Exception:
        return path
    width, height = img.size
    longest = max(width, height)
    if longest > max_side:
        scale = max_side / longest
        img = img.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
    stem = os.path.splitext(os.path.basename(path))[0]
    out_path = os.path.join(THUMB_CACHE, f"{stem}_thumb_{max_side}.jpg")
    try:
        img.save(out_path, "JPEG", quality=85, optimize=True)
    except Exception:
        # Cache dir unwritable or encode error — serve the original instead.
        return path
    return out_path
def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
    """Downscale *im* so its longest side is at most *max_side*; never upscales."""
    longest = max(im.size)
    if longest <= max_side:
        return im
    scale = max_side / longest
    new_size = (int(im.size[0] * scale), int(im.size[1] * scale))
    return im.resize(new_size, Image.LANCZOS)
def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_text: str) -> str:
    """Join trigger word, begin text, caption, and end text with single spaces.

    Blank (empty/whitespace-only) pieces are dropped; all pieces are stripped.
    """
    pieces = (trigger_word, begin_text, caption, end_text)
    return " ".join(p.strip() for p in pieces if p.strip())
# Instruction + caption
def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str) -> str:
    """Assemble the model instruction from styles, extra options, and the name field.

    "{name}" placeholders are substituted only when NAME_OPTION is among the
    selected extras; an empty name falls back to the literal "{NAME}".
    """
    chosen = style_list if style_list else ["Descriptive (short)"]
    templates = [CAPTION_TYPE_MAP.get(style, "") for style in chosen]
    text = " ".join(t for t in templates if t).strip()
    extras = extra_opts or []
    if extras:
        text += " " + " ".join(extras)
    if NAME_OPTION in extras:
        text = text.replace("{name}", (name_value or "{NAME}").strip())
    return text
def logo_b64_img() -> str:
    """Return an inline <img> tag with the base64-encoded logo, or "" if none found."""
    candidates = [
        os.path.join(APP_DIR, "forgecaptions-logo.png"),
        os.path.join(APP_DIR, "captionforge-logo.png"),
        "/home/user/app/forgecaptions-logo.png",
        "forgecaptions-logo.png",
        "captionforge-logo.png",
    ]
    for candidate in candidates:
        if not os.path.exists(candidate):
            continue
        with open(candidate, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("ascii")
        return f"<img src='data:image/png;base64,{encoded}' alt='ForgeCaptions' class='cf-logo'>"
    return ""
# ---------------------------------------------------------------------
# Persistence
# ---------------------------------------------------------------------
def save_session(rows: List[dict]):
    """Persist the session rows to SESSION_FILE as pretty-printed JSON."""
    payload = json.dumps(rows, ensure_ascii=False, indent=2)
    with open(SESSION_FILE, "w", encoding="utf-8") as fh:
        fh.write(payload)
def load_session() -> List[dict]:
    """Load the persisted session rows; return [] when no session file exists."""
    if not os.path.exists(SESSION_FILE):
        return []
    with open(SESSION_FILE, "r", encoding="utf-8") as fh:
        return json.load(fh)
def save_settings(cfg: dict):
    """Write the settings dict to SETTINGS_FILE as pretty-printed JSON."""
    payload = json.dumps(cfg, ensure_ascii=False, indent=2)
    with open(SETTINGS_FILE, "w", encoding="utf-8") as fh:
        fh.write(payload)
def load_settings() -> dict:
    """Load persisted settings, fill in defaults, and migrate legacy style labels.

    Always returns a dict with every default key present and "styles" reduced
    to valid STYLE_OPTIONS entries (falling back to ["Descriptive (short)"]).
    """
    if os.path.exists(SETTINGS_FILE):
        with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
            cfg = json.load(f)
    else:
        cfg = {}
    defaults = {
        "dataset_name": "forgecaptions",
        "temperature": 0.6,
        "top_p": 0.9,
        "max_tokens": 256,
        "max_side": 896,
        "styles": ["Character training (long)"],  # ← default changed
        "extras": [],
        "name": "",
        "trigger": "",
        "begin": "",
        "end": "",
        "shape_aliases_enabled": True,
        "shape_aliases": [],
        "excel_thumb_px": 128,
    }
    # setdefault keeps whatever the user already saved; only gaps are filled.
    for k, v in defaults.items():
        cfg.setdefault(k, v)
    # Older settings files used un-suffixed style labels — map to "(short)".
    legacy_map = {
        "Descriptive": "Descriptive (short)",
        "LoRA (Flux_D Realism)": "LoRA (Flux_D Realism) (short)",
        "Portrait (photography)": "Portrait (photography) (short)",
        "Landscape (photography)": "Landscape (photography) (short)",
        "Art analysis (no artist names)": "Art analysis (no artist names) (short)",
        "E-commerce product": "E-commerce product (short)",
    }
    styles = cfg.get("styles") or []
    migrated = []
    # Tolerate a single string as well as a list of styles.
    for s in styles if isinstance(styles, list) else [styles]:
        migrated.append(legacy_map.get(s, s))
    # Drop anything that is no longer a valid STYLE_OPTIONS entry.
    migrated = [s for s in migrated if s in STYLE_OPTIONS]
    if not migrated:
        migrated = ["Descriptive (short)"]
    cfg["styles"] = migrated
    return cfg
def save_journal(data: dict):
    """Persist journal data to JOURNAL_FILE as pretty-printed JSON."""
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    with open(JOURNAL_FILE, "w", encoding="utf-8") as fh:
        fh.write(payload)
def load_journal() -> dict:
    """Load journal data; return {} when no journal file exists."""
    if not os.path.exists(JOURNAL_FILE):
        return {}
    with open(JOURNAL_FILE, "r", encoding="utf-8") as fh:
        return json.load(fh)
# ---------------------------------------------------------------------
# Shape Aliases
# ---------------------------------------------------------------------
def _compile_shape_aliases_from_file():
    """Compile (regex, replacement-name) pairs from the persisted alias settings.

    Returns [] when alias replacement is disabled or no valid entries exist.
    """
    cfg = load_settings()
    if not cfg.get("shape_aliases_enabled", True):
        return []
    compiled = []
    for entry in cfg.get("shape_aliases", []):
        raw = (entry.get("shape") or "").strip()
        name = (entry.get("name") or "").strip()
        if not (raw and name):
            continue
        # A single cell may hold several synonyms separated by commas or pipes.
        tokens = [tok.strip() for tok in re.split(r"[|,]", raw) if tok.strip()]
        if not tokens:
            continue
        # De-dup and sort longest-first so "diamond" cannot eat "diamond emblem".
        tokens = sorted(set(tokens), key=lambda tok: -len(tok))
        # Whole-word match; optionally swallow a trailing "-shaped"/"shaped".
        pattern = r"\b(?:" + "|".join(re.escape(tok) for tok in tokens) + r")(?:-?shaped)?\b"
        compiled.append((re.compile(pattern, flags=re.I), name))
    return compiled
# Module-level cache of compiled alias patterns; rebuilt whenever settings change.
_SHAPE_ALIASES = _compile_shape_aliases_from_file()
def _refresh_shape_aliases_cache():
    """Recompile the alias cache after the settings file has been updated."""
    global _SHAPE_ALIASES
    _SHAPE_ALIASES = _compile_shape_aliases_from_file()
def apply_shape_aliases(caption: str) -> str:
    """Replace every configured shape token in *caption* with "(name)"."""
    result = caption
    for pattern, name in _SHAPE_ALIASES:
        result = pattern.sub(f"({name})", result)
    return result
def get_shape_alias_rows_ui_defaults():
    """Return (table rows, enabled flag) used to seed the alias editor UI.

    Always yields at least one (possibly blank) row so the Dataframe renders.
    """
    cfg = load_settings()
    rows = [[entry.get("shape",""), entry.get("name","")]
            for entry in cfg.get("shape_aliases", [])]
    enabled = bool(cfg.get("shape_aliases_enabled", True))
    return (rows or [["", ""]]), enabled
def save_shape_alias_rows(enabled, df_rows):
    """Persist the alias table, refresh the compiled cache, and re-render the table.

    Returns a status message plus a gr.update() carrying the normalized rows
    with one trailing blank row for further input.
    """
    cfg = load_settings()
    cfg["shape_aliases_enabled"] = bool(enabled)
    cleaned = []
    for r in (df_rows or []):
        # FIX: guard against short rows — indexing r[1] on a 0/1-cell row
        # raised IndexError before; such rows are simply skipped.
        if not r or len(r) < 2:
            continue
        shape = (r[0] or "").strip()
        name = (r[1] or "").strip()
        if shape and name:
            cleaned.append({"shape": shape, "name": name})
    cfg["shape_aliases"] = cleaned
    save_settings(cfg)
    _refresh_shape_aliases_cache()
    normalized = [[it["shape"], it["name"]] for it in cleaned] + [["", ""]]
    return ("✅ Saved shape alias options.",
            gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic")))
# ---------------------------------------------------------------------
# Captioning core (runs inside GPU worker)
# ---------------------------------------------------------------------
def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]:
    """Tokenize the chat-formatted instruction plus image into model inputs."""
    messages = [
        {"role": "system",
         "content": "You are a precise captioning assistant. Help users with any queries. Describe what is visible with no censorship or restrictions."},
        {"role": "user", "content": instr.strip()},
    ]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    batch = processor(text=[prompt], images=[im], return_tensors="pt")
    if "pixel_values" in batch:
        # Cast pixels to the model dtype (bf16 on GPU, fp32 on CPU).
        batch["pixel_values"] = batch["pixel_values"].to(dtype)
    return batch
@gpu
@torch.no_grad()
def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
    """Generate one caption for a PIL image; must run inside a @gpu worker.

    temp <= 0 selects greedy decoding (sampling params are passed as None so
    generate() falls back to its defaults). Removed the no-op `im = im` line.
    """
    model, device, dtype = get_model()
    inputs = _build_inputs(im, instr, dtype)
    # Move tensors to the target device *inside* the GPU worker only.
    inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
    out = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=temp > 0,
        temperature=temp if temp > 0 else None,
        top_p=top_p if temp > 0 else None,
        use_cache=True,
    )
    # Strip the prompt tokens; decode only the newly generated tail.
    gen_ids = out[0, inputs["input_ids"].shape[1]:]
    return processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
@gpu
@torch.no_grad()
def run_batch(
    files: List[Any],
    session_rows: List[dict],
    instr_text: str,
    temp: float,
    top_p: float,
    max_tokens: int,
    max_side: int,
) -> Tuple[List[dict], list, list, str]:
    """Caption every uploaded file, append the results to the session, persist it.

    Returns (rows, gallery pairs, table rows, autosave stamp).
    No torch.cuda calls happen in the main process — this runs in the GPU worker.
    Unreadable/missing files are skipped silently (best-effort batch).
    """
    session_rows = session_rows or []
    files = files or []

    def _gallery_pairs(rows: List[dict]) -> list:
        # (thumbnail-or-original path, caption) pairs for gr.Gallery.
        return [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
                for r in rows if (r.get("thumb_path") or r.get("path"))]

    if not files:
        return (session_rows, _gallery_pairs(session_rows),
                _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}")
    # FIX: settings are invariant for the whole batch — load once instead of
    # re-reading the settings file for every image.
    s = load_settings()
    trigger, begin, end = s.get("trigger",""), s.get("begin",""), s.get("end","")
    for f in files:
        path = f if isinstance(f, str) else getattr(f, "name", None) or getattr(f, "path", None)
        if not path or not os.path.exists(path):
            continue
        try:
            im = Image.open(path).convert("RGB")
        except Exception:
            continue
        im = resize_for_model(im, max_side)
        cap = caption_once(im, instr_text, temp, top_p, max_tokens)
        cap = apply_shape_aliases(cap)
        cap = apply_prefix_suffix(cap, trigger, begin, end)
        session_rows.append({
            "filename": os.path.basename(path),
            "caption": cap,
            "path": path,
            "thumb_path": ensure_thumb(path, 256),
        })
    save_session(session_rows)
    return (session_rows, _gallery_pairs(session_rows),
            _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}")
@gpu
@torch.no_grad()
def caption_single(img: Image.Image, instr: str) -> str:
    """Caption a single image using the persisted generation settings."""
    if img is None:
        return "No image provided."
    cfg = load_settings()
    resized = resize_for_model(img, int(cfg.get("max_side", 896)))
    caption = caption_once(
        resized, instr,
        cfg.get("temperature", 0.6),
        cfg.get("top_p", 0.9),
        cfg.get("max_tokens", 256),
    )
    caption = apply_shape_aliases(caption)
    return apply_prefix_suffix(caption, cfg.get("trigger",""), cfg.get("begin",""), cfg.get("end",""))
# tiny warmup so Spaces sees a GPU function at startup
@gpu
@torch.no_grad()
def _gpu_startup_warm():
    """Run one tiny greedy generation so the GPU worker/model warm up at launch."""
    try:
        im = Image.new("RGB", (64, 64), (127,127,127))
        _ = caption_once(im, "Warm up.", temp=0.0, top_p=1.0, max_tokens=8)
        print("[ForgeCaptions] GPU warmup complete")
    except Exception as e:
        # Best effort: a warmup failure must never block app startup.
        print("[ForgeCaptions] GPU warmup skipped:", e)
# ---------------------------------------------------------------------
# Export helpers
# ---------------------------------------------------------------------
def _rows_to_table(rows: List[dict]) -> list:
return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
tbl = table_value or []
new = []
for i, r in enumerate(rows or []):
r = dict(r)
if i < len(tbl) and len(tbl[i]) >= 2:
r["filename"] = str(tbl[i][0]) if tbl[i][0] is not None else r.get("filename","")
r["caption"] = str(tbl[i][1]) if tbl[i][1] is not None else r.get("caption","")
new.append(r)
return new
def export_csv_from_table(table_value: Any) -> str:
    """Write the table to a timestamped CSV under /tmp and return its path."""
    out_path = f"/tmp/forgecaptions_{int(time.time())}.csv"
    with open(out_path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["filename", "caption"])
        for row in (table_value or []):
            writer.writerow(row)
    return out_path
def _resize_for_excel(path: str, px: int) -> str:
    """Create a <=px-long-side JPEG copy of *path* for Excel embedding.

    Falls back to the original *path* if the image cannot be read or saved.
    """
    try:
        img = Image.open(path).convert("RGB")
    except Exception:
        return path
    width, height = img.size
    longest = max(width, height)
    if longest > px:
        scale = px / longest
        img = img.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
    stem = os.path.splitext(os.path.basename(path))[0]
    out_path = os.path.join(EXCEL_THUMB_DIR, f"{stem}_{px}px.jpg")
    try:
        img.save(out_path, "JPEG", quality=85, optimize=True)
    except Exception:
        return path
    return out_path
def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_px: int) -> str:
    """Export the session to a timestamped /tmp/*.xlsx with an embedded thumbnail per row.

    Captions are taken from the (possibly edited) table, keyed by filename,
    falling back to the stored session caption. Raises RuntimeError if
    openpyxl is not installed.
    """
    try:
        from openpyxl import Workbook
        from openpyxl.drawing.image import Image as XLImage
    except Exception as e:
        raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
    # Index edited captions by filename so table edits win over session values.
    caption_by_file = {}
    for row in (table_value or []):
        if not row:
            continue
        fn = str(row[0]) if len(row) > 0 else ""
        cap = str(row[1]) if len(row) > 1 and row[1] is not None else ""
        if fn:
            caption_by_file[fn] = cap
    wb = Workbook(); ws = wb.active; ws.title = "ForgeCaptions"
    ws.append(["image", "filename", "caption"])
    # Column A holds the thumbnail, B the filename, C the caption text.
    ws.column_dimensions["A"].width = 24
    ws.column_dimensions["B"].width = 42
    ws.column_dimensions["C"].width = 100
    row_h = int(int(thumb_px) * 0.75)  # px→pt-ish
    r_i = 2
    for r in (session_rows or []):
        fn = r.get("filename","")
        cap = caption_by_file.get(fn, r.get("caption",""))
        ws.cell(row=r_i, column=2, value=fn)
        ws.cell(row=r_i, column=3, value=cap)
        img_path = r.get("thumb_path") or r.get("path")
        if img_path and os.path.exists(img_path):
            try:
                resized = _resize_for_excel(img_path, int(thumb_px))
                xlimg = XLImage(resized)
                ws.add_image(xlimg, f"A{r_i}")
                ws.row_dimensions[r_i].height = row_h
            except Exception:
                # Best effort: a bad image must not abort the whole export.
                pass
        r_i += 1
    out = f"/tmp/forgecaptions_{int(time.time())}.xlsx"
    wb.save(out)
    return out
def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]:
    """Apply table edits to the session, persist it, and rebuild the gallery pairs."""
    merged = _table_to_rows(table_value, session_rows or [])
    save_session(merged)
    pairs = []
    for row in merged:
        src = row.get("thumb_path") or row.get("path")
        if src:
            pairs.append((src, row.get("caption","")))
    return merged, pairs, f"Saved • {time.strftime('%H:%M:%S')}"
# ---------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------
# Global stylesheet: hero header layout, logo sizing, and the scrollable
# gallery/table panels (heights are synced by the <script> injected below).
BASE_CSS = """
:root{--galleryW:50%;--tableW:50%;}
.gradio-container{max-width:100%!important}
.cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
margin:4px 0 12px; text-align:center;}
.cf-hero > div { text-align:center; }
.cf-logo{height:calc(3.25rem + 3 * 1.1rem + 18px);width:auto;object-fit:contain}
.cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
.cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
.cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
#cfGal .grid > div { height: 96px; }
"""
with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
    # ensure Spaces sees a GPU function at start (without touching CUDA in main)
    demo.load(_gpu_startup_warm, inputs=None, outputs=None)
    settings = load_settings()
    # Keep only styles that still exist; otherwise fall back to the default.
    settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (long)"]
    # Hero header: logo plus the title/subtitle stack.
    gr.HTML(value=f"""
<div class="cf-hero">
{logo_b64_img()}
<div>
<h1 class="cf-title">ForgeCaptions</h1>
<div class="cf-sub">Batch captioning</div>
<div class="cf-sub">Scrollable editor & autosave</div>
<div class="cf-sub">CSV / Excel export</div>
</div>
</div>
<hr>""")
    # ── Controls
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("Caption style (choose one or combine)", open=True):
                    style_checks = gr.CheckboxGroup(
                        choices=STYLE_OPTIONS,
                        value=settings.get("styles", ["Character training (long)"]),
                        label=None
                    )
                with gr.Accordion("Extra options", open=False):
                    extra_opts = gr.CheckboxGroup(
                        choices=[NAME_OPTION] + EXTRA_CHOICES,
                        value=settings.get("extras", []),
                        label=None
                    )
                with gr.Accordion("Name & Prefix/Suffix", open=False):
                    name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
                    trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
                    add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
                    add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
            with gr.Column(scale=1):
                with gr.Accordion("Model Instructions", open=False):
                    instruction_preview = gr.Textbox(label=None, lines=12)
                dataset_name = gr.Textbox(label="Dataset name (export title prefix)",
                                          value=settings.get("dataset_name", "forgecaptions"))
                max_side = gr.Slider(256, 1024, settings.get("max_side", 896), step=32, label="Max side (resize)")
                excel_thumb_px = gr.Slider(64, 256, value=settings.get("excel_thumb_px", 128),
                                           step=8, label="Excel thumbnail size (px)")
                gr.Markdown("Generation settings: temperature 0.6 • top-p 0.9 • max tokens 256")

    def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
        # Rebuild the instruction preview AND autosave every control change.
        instr = final_instruction(styles or ["Character training (long)"], extra or [], name_value)
        cfg = load_settings()
        cfg.update({
            "styles": styles or ["Character training (long)"],
            "extras": extra or [],
            "name": name_value,
            "trigger": trigv, "begin": begv, "end": endv,
            "excel_thumb_px": int(excel_px),
            "max_side": int(ms),
        })
        save_settings(cfg)
        return instr

    # Every listed control refreshes the preview (and persists settings) on change.
    for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side]:
        comp.change(_refresh_instruction,
                    inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side],
                    outputs=[instruction_preview])
    # Populate the preview once on page load.
    demo.load(lambda s,e,n: final_instruction(s or ["Character training (long)"], e or [], n),
              inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])

    # ── Shape Aliases (improved)
    with gr.Accordion("Shape Aliases", open=False):
        gr.Markdown(
            "### 🔷 Shape Aliases\n"
            "Replace literal **shape tokens** in captions with a preferred **name**.\n\n"
            "**How to use:**\n"
            "- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `penis, cock | phallic`\n"
            "- Right column = replacement name, e.g. `family-emblem`\n\n"
            "Matches are case-insensitive, use whole words, and also catch `*-shaped` (e.g., `diamond-shaped`).\n"
            "Multi-word phrases are supported."
        )
        init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
        enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled)
        alias_table = gr.Dataframe(
            headers=["shape (literal token)", "name to insert"],
            value=init_rows,
            col_count=(2, "fixed"),
            row_count=(max(1, len(init_rows)), "dynamic"),
            datatype=["str","str"],
            type="array",
            interactive=True
        )
        with gr.Row():
            add_row_btn = gr.Button("+ Add row", variant="secondary")
            clear_btn = gr.Button("Clear", variant="secondary")
            save_btn = gr.Button("💾 Save", variant="primary")
        save_status = gr.Markdown("")

        def _add_row(cur):
            # Append one blank editable row to the alias table.
            cur = (cur or []) + [["", ""]]
            return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))

        def _clear_rows():
            # Reset the alias table to a single blank row.
            return gr.update(value=[["", ""]], row_count=(1, "dynamic"))

        add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
        clear_btn.click(_clear_rows, outputs=[alias_table])
        save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table], outputs=[save_status, alias_table])

    # ── Tabs: Single & Batch
    with gr.Tabs():
        with gr.Tab("Single"):
            input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512)
            single_caption_btn = gr.Button("Caption")
            single_caption_out = gr.Textbox(label="Caption (single)")
            single_caption_btn.click(
                caption_single,
                inputs=[input_image_single, instruction_preview],
                outputs=[single_caption_out]
            )
        with gr.Tab("Batch"):
            with gr.Accordion("Uploaded images", open=True):
                input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
            run_button = gr.Button("Caption batch", variant="primary")

    # ── Results + Table (same position)
    rows_state = gr.State(load_session())   # per-browser-session copy of the saved rows
    autosave_md = gr.Markdown("Ready.")
    with gr.Row():
        with gr.Column(scale=1):
            gallery = gr.Gallery(
                label="Results (image + caption)",
                show_label=True,
                columns=3,
                height=520,
                elem_id="cfGal",
                elem_classes=["cf-scroll"]
            )
        with gr.Column(scale=1):
            table = gr.Dataframe(
                label="Editable captions (whole session)",
                value=_rows_to_table(load_session()),
                headers=["filename", "caption"],
                interactive=True,
                wrap=True,
                elem_id="cfTable",
                elem_classes=["cf-scroll"]
            )

    # Exports
    with gr.Row():
        with gr.Column():
            export_csv_btn = gr.Button("Export CSV")
            csv_file = gr.File(label="CSV file", visible=False)
        with gr.Column():
            export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails")
            xlsx_file = gr.File(label="Excel file", visible=False)

    def _initial_gallery(rows):
        # Seed the gallery from the persisted session rows on page load.
        rows = rows or []
        return [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
                for r in rows if (r.get("thumb_path") or r.get("path"))]
    demo.load(_initial_gallery, inputs=[rows_state], outputs=[gallery])

    # Scroll sync: equalize panel heights and mirror scrollTop between the
    # gallery (#cfGal) and the table (#cfTable); polls until both exist.
    gr.HTML("""
<script>
(function () {
function findGalleryScrollRoot() {
const host = document.querySelector("#cfGal");
if (!host) return null;
return host.querySelector(".grid") || host.querySelector("[data-testid='gallery']") || host;
}
function findTableScrollRoot() {
const host = document.querySelector("#cfTable");
if (!host) return null;
return host.querySelector(".wrap") ||
host.querySelector(".dataframe-wrap") ||
(host.querySelector("table") ? host.querySelector("table").parentElement : null) ||
host;
}
function syncScroll(a, b) {
if (!a || !b) return;
let lock = false;
const onScrollA = () => { if (lock) return; lock = true; b.scrollTop = a.scrollTop; lock = false; };
const onScrollB = () => { if (lock) return; lock = true; a.scrollTop = b.scrollTop; lock = false; };
a.addEventListener("scroll", onScrollA, { passive: true });
b.addEventListener("scroll", onScrollB, { passive: true });
}
let tries = 0;
const timer = setInterval(() => {
tries++;
const gal = findGalleryScrollRoot();
const tab = findTableScrollRoot();
if (gal && tab) {
const H = Math.min(gal.clientHeight || 520, tab.clientHeight || 520);
gal.style.maxHeight = H + "px";
gal.style.overflowY = "auto";
tab.style.maxHeight = H + "px";
tab.style.overflowY = "auto";
syncScroll(gal, tab);
clearInterval(timer);
}
if (tries > 20) clearInterval(timer);
}, 100);
})();
</script>
""")

    # Batch run → rows + gallery + table
    def _run_click(files, rows, instr, ms):
        # Generation params come from persisted settings; max_side from the slider.
        s = load_settings()
        t = s.get("temperature", 0.6)
        p = s.get("top_p", 0.9)
        m = s.get("max_tokens", 256)
        new_rows, gal, tbl, stamp = run_batch(files, rows or [], instr, t, p, m, int(ms))
        return new_rows, gal, tbl, stamp
    run_button.click(
        _run_click,
        inputs=[input_files, rows_state, instruction_preview, max_side],
        outputs=[rows_state, gallery, table, autosave_md]
    )
    # Table edits sync
    table.change(
        sync_table_to_session,
        inputs=[table, rows_state],
        outputs=[rows_state, gallery, autosave_md]
    )
    # Exports
    export_csv_btn.click(
        lambda tbl: (export_csv_from_table(tbl), gr.update(visible=True)),
        inputs=[table], outputs=[csv_file, csv_file]
    )
    export_xlsx_btn.click(
        lambda tbl, rows, px: (export_excel_with_thumbs(tbl, rows or [], int(px)), gr.update(visible=True)),
        inputs=[table, rows_state, excel_thumb_px], outputs=[xlsx_file, xlsx_file]
    )
# Launch (SSR off for stability on Spaces)
if __name__ == "__main__":
    demo.queue(max_size=64).launch(
        server_name="0.0.0.0",
        # Spaces injects PORT; default to Gradio's conventional 7860 locally.
        server_port=int(os.getenv("PORT", "7860")),
        ssr_mode=False,
        debug=True,
        show_error=True,
    )