Spaces:
Runtime error
Runtime error
import os, io, csv, time, json, base64, re
from typing import List, Tuple, Dict, Any

# ---------------------------------------------------------------------
# Caching
# ---------------------------------------------------------------------
# Point the Hugging Face cache at a writable path BEFORE importing
# transformers, so model downloads land in this directory.
os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

import gradio as gr
from PIL import Image
import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor

# ── HF Spaces GPU decorator (no-op on CPU/local) ─────────────────────
try:
    import spaces
    gpu = spaces.GPU()
except Exception:  # local/CPU
    def gpu(f): return f

# File locations. /tmp paths are ephemeral on Spaces; they survive only
# for the lifetime of the running container.
APP_DIR = os.getcwd()
SESSION_FILE = "/tmp/session.json"
SETTINGS_FILE = "/tmp/cf_settings.json"
JOURNAL_FILE = "/tmp/cf_journal.json"
THUMB_CACHE = os.path.expanduser("~/.cache/forgecaptions/thumbs")
EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
os.makedirs(THUMB_CACHE, exist_ok=True)
os.makedirs(EXCEL_THUMB_DIR, exist_ok=True)
# ---------------------------------------------------------------------
# Model identifiers
# ---------------------------------------------------------------------
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"

# Load the processor on CPU (safe in stateless env)
processor = AutoProcessor.from_pretrained(MODEL_PATH)

# Lazy GPU/CPU model (created inside GPU worker only); see get_model().
_MODEL = None
_DEVICE = "cpu"
_DTYPE = torch.float32
def get_model():
    """Lazily create the LLaVA model and return (model, device, dtype).

    Only call this from inside @gpu functions — CUDA must not be touched
    in the main process on Spaces. The model is built once and cached in
    module globals; subsequent calls reuse it.
    """
    global _MODEL, _DEVICE, _DTYPE
    if _MODEL is not None:
        return _MODEL, _DEVICE, _DTYPE
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        _DEVICE, _DTYPE = "cuda", torch.bfloat16
        placement = 0  # GPU:0 (inside GPU worker process)
    else:
        _DEVICE, _DTYPE = "cpu", torch.float32
        placement = "cpu"
    _MODEL = LlavaForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=_DTYPE,
        low_cpu_mem_usage=True,
        device_map=placement,
    )
    _MODEL.eval()
    print(f"[ForgeCaptions] Model ready on {_DEVICE} dtype={_DTYPE}")
    return _MODEL, _DEVICE, _DTYPE
print(f"[ForgeCaptions] Gradio version: {gr.__version__}")

# ---------------------------------------------------------------------
# Instruction templates & options
# ---------------------------------------------------------------------
# Checkbox choices shown in the "Caption style" accordion. Each entry is
# looked up in CAPTION_TYPE_MAP by final_instruction().
STYLE_OPTIONS = [
    "Descriptive (short)", "Descriptive (long)",
    "Character training (short)", "Character training (long)",
    "LoRA (Flux_D Realism) (short)", "LoRA (Flux_D Realism) (long)",
    "E-commerce product (short)", "E-commerce product (long)",
    "Portrait (photography) (short)", "Portrait (photography) (long)",
    "Landscape (photography) (short)", "Landscape (photography) (long)",
    "Art analysis (no artist names) (short)", "Art analysis (no artist names) (long)",
    "Social caption (short)", "Social caption (long)",
    "Aesthetic tags (comma-sep)"
]
# Style name → model instruction. Keys must match STYLE_OPTIONS exactly:
# final_instruction() falls back to "" for unknown keys, which silently
# yields an empty prompt.
# Fix: the Flux.Dev entries were keyed "Flux_D (short)/(long)" while the
# UI offers "LoRA (Flux_D Realism) (short)/(long)", so those styles
# produced no instruction; keys renamed to match the UI.
CAPTION_TYPE_MAP = {
    "Descriptive (short)": "One sentence (≤25 words) describing the most important visible elements only. No speculation.",
    "Descriptive (long)": "Write a detailed description for this image.",
    "Character training (short)": (
        "Output a concise, prompt-like caption for character LoRA/ID training. "
        "Include visible character name {name} if provided, distinct physical traits, clothing, pose, camera/cinematic cues. "
        "No backstory; no non-visible traits. Prefer comma-separated phrases."
    ),
    "Character training (long)": (
        "Write a thorough, training-ready caption for a character dataset. "
        "Use {name} if provided; describe only what is visible: physique, face/hair, clothing, accessories, actions, pose, "
        "camera angle/focal cues, lighting, background context. 1–3 sentences; no backstory or meta."
    ),
    "LoRA (Flux_D Realism) (short)": "Output a short Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
    "LoRA (Flux_D Realism) (long)": "Output a long Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
    "Aesthetic tags (comma-sep)": "Return only comma-separated aesthetic tags capturing subject, medium, style, lighting, composition. No sentences.",
    "E-commerce product (short)": "One sentence highlighting key attributes, material, color, use case. No fluff.",
    "E-commerce product (long)": "Write a crisp product description highlighting key attributes, materials, color, usage, and distinguishing traits.",
    "Portrait (photography) (short)": "One sentence portrait description: subject, pose/expression, camera angle, lighting, background.",
    "Portrait (photography) (long)": "Describe a portrait: subject, age range, pose, facial expression, camera angle, focal length cues, lighting, background.",
    "Landscape (photography) (short)": "One sentence landscape description: major elements, time of day, weather, vantage point, mood.",
    "Landscape (photography) (long)": "Describe landscape elements, time of day, weather, vantage point, composition, and mood.",
    "Art analysis (no artist names) (short)": "One sentence describing medium, style, composition, palette; do not mention artist/title.",
    "Art analysis (no artist names) (long)": "Analyze the artwork's visible elements, medium, style, composition, palette. Do not mention artist names or titles.",
    "Social caption (short)": "Write a short, catchy caption (max 25 words) describing the visible content. No hashtags.",
    "Social caption (long)": "Write a slightly longer, engaging caption (≤50 words) describing the visible content. No hashtags."
}
# Optional instruction fragments; the checked ones are appended verbatim
# to the style instruction by final_instruction().
EXTRA_CHOICES = [
    "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
    "Do NOT include information about whether there is a watermark or not.",
    "Do NOT use any ambiguous language.",
    "ONLY describe the most important elements of the image.",
    "Include information about the ages of any people/characters when applicable.",
    "Explicitly specify the vantage height (eye-level, low-angle worm’s-eye, bird’s-eye, drone, rooftop, etc.).",
    "Focus captions only on clothing/fashion details.",
    "Focus on setting, scenery, and context; ignore subject details.",
    "ONLY describe the subject’s pose, movement, or action. Do NOT mention appearance, clothing, or setting.",
    "Do NOT include anything sexual; keep it PG.",
    "Include synonyms/alternate phrasing to diversify training set.",
    "ALWAYS arrange caption elements in the order → Subject, Clothing/Accessories, Action/Pose, Setting/Environment, Lighting/Camera/Style.",
    "Do NOT mention the image's resolution.",
    "Include information about depth, lighting, and camera angle.",
    "Include information on composition (rule of thirds, symmetry, leading lines, etc).",
    "Specify the depth of field and whether the background is in focus or blurred.",
    "If applicable, mention the likely use of artificial or natural lighting sources.",
    "Identify the image orientation (portrait, landscape, or square) if obvious.",
]

# Special extra: when checked, final_instruction() substitutes {name}
# with the user-supplied character name.
NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def ensure_thumb(path: str, max_side=256) -> str:
    """Create (or reuse) a JPEG thumbnail for *path* in THUMB_CACHE.

    Returns the thumbnail path, or the original *path* on any failure
    (unreadable image, unwritable cache) so callers always get a usable
    file reference.
    """
    try:
        img = Image.open(path).convert("RGB")
    except Exception:
        return path
    width, height = img.size
    longest = max(width, height)
    if longest > max_side:
        scale = max_side / longest
        img = img.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
    stem = os.path.splitext(os.path.basename(path))[0]
    out_path = os.path.join(THUMB_CACHE, f"{stem}_thumb_{max_side}.jpg")
    try:
        img.save(out_path, "JPEG", quality=85, optimize=True)
    except Exception:
        return path
    return out_path
def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
    """Downscale *im* so its longest side is ≤ *max_side*; never upscales."""
    width, height = im.size
    longest = max(width, height)
    if longest <= max_side:
        return im
    scale = max_side / longest
    return im.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_text: str) -> str:
    """Assemble "<trigger> <begin> <caption> <end>", skipping blank pieces.

    Every piece is stripped; empty pieces contribute nothing (no double
    spaces), so an all-blank input yields "".
    """
    pieces = [
        trigger_word.strip(),
        begin_text.strip(),
        caption.strip(),
        end_text.strip(),
    ]
    return " ".join(piece for piece in pieces if piece)
# Instruction + caption
def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str) -> str:
    """Compose the model instruction from style choices and extra options.

    Unknown styles map to "" and drop out. When NAME_OPTION is among the
    extras, "{name}" is replaced with the provided name (or the literal
    "{NAME}" placeholder when blank).
    """
    chosen = style_list or ["Descriptive (short)"]
    fragments = (CAPTION_TYPE_MAP.get(style, "") for style in chosen)
    core = " ".join(frag for frag in fragments if frag).strip()
    if extra_opts:
        core += " " + " ".join(extra_opts)
    if NAME_OPTION in (extra_opts or []):
        core = core.replace("{name}", (name_value or "{NAME}").strip())
    return core
def logo_b64_img() -> str:
    """Return an inline <img> tag with the base64-encoded logo, or ""."""
    # Probe known locations; first hit wins.
    candidates = [
        os.path.join(APP_DIR, "forgecaptions-logo.png"),
        os.path.join(APP_DIR, "captionforge-logo.png"),
        "/home/user/app/forgecaptions-logo.png",
        "forgecaptions-logo.png",
        "captionforge-logo.png",
    ]
    for candidate in candidates:
        if not os.path.exists(candidate):
            continue
        with open(candidate, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("ascii")
        return f"<img src='data:image/png;base64,{encoded}' alt='ForgeCaptions' class='cf-logo'>"
    return ""
# ---------------------------------------------------------------------
# Persistence
# ---------------------------------------------------------------------
def save_session(rows: List[dict]):
    """Write the session rows to SESSION_FILE as pretty-printed JSON."""
    with open(SESSION_FILE, "w", encoding="utf-8") as fh:
        json.dump(rows, fh, ensure_ascii=False, indent=2)
def load_session() -> List[dict]:
    """Read session rows from SESSION_FILE; [] when no session exists."""
    if not os.path.exists(SESSION_FILE):
        return []
    with open(SESSION_FILE, "r", encoding="utf-8") as fh:
        return json.load(fh)
def save_settings(cfg: dict):
    """Persist the settings dict to SETTINGS_FILE as JSON."""
    with open(SETTINGS_FILE, "w", encoding="utf-8") as fh:
        json.dump(cfg, fh, ensure_ascii=False, indent=2)
def load_settings() -> dict:
    """Load settings from disk, fill in defaults, migrate legacy styles.

    Guarantees every default key exists and that cfg["styles"] contains
    only entries present in STYLE_OPTIONS (never empty).
    """
    cfg = {}
    if os.path.exists(SETTINGS_FILE):
        with open(SETTINGS_FILE, "r", encoding="utf-8") as fh:
            cfg = json.load(fh)
    defaults = {
        "dataset_name": "forgecaptions",
        "temperature": 0.6,
        "top_p": 0.9,
        "max_tokens": 256,
        "max_side": 896,
        "styles": ["Character training (long)"],  # ← default changed
        "extras": [],
        "name": "",
        "trigger": "",
        "begin": "",
        "end": "",
        "shape_aliases_enabled": True,
        "shape_aliases": [],
        "excel_thumb_px": 128,
    }
    for key, value in defaults.items():
        cfg.setdefault(key, value)
    # Older settings files stored style names without the (short)/(long)
    # suffix; map them onto the current option names.
    legacy_map = {
        "Descriptive": "Descriptive (short)",
        "LoRA (Flux_D Realism)": "LoRA (Flux_D Realism) (short)",
        "Portrait (photography)": "Portrait (photography) (short)",
        "Landscape (photography)": "Landscape (photography) (short)",
        "Art analysis (no artist names)": "Art analysis (no artist names) (short)",
        "E-commerce product": "E-commerce product (short)",
    }
    raw_styles = cfg.get("styles") or []
    if not isinstance(raw_styles, list):
        raw_styles = [raw_styles]
    migrated = [legacy_map.get(style, style) for style in raw_styles]
    migrated = [style for style in migrated if style in STYLE_OPTIONS]
    cfg["styles"] = migrated or ["Descriptive (short)"]
    return cfg
def save_journal(data: dict):
    """Persist the journal dict to JOURNAL_FILE as JSON."""
    with open(JOURNAL_FILE, "w", encoding="utf-8") as fh:
        json.dump(data, fh, ensure_ascii=False, indent=2)
def load_journal() -> dict:
    """Read the journal from JOURNAL_FILE; {} when none exists."""
    if not os.path.exists(JOURNAL_FILE):
        return {}
    with open(JOURNAL_FILE, "r", encoding="utf-8") as fh:
        return json.load(fh)
# ---------------------------------------------------------------------
# Shape Aliases
# ---------------------------------------------------------------------
def _compile_shape_aliases_from_file():
    """Build [(compiled_pattern, replacement_name)] from saved settings.

    Each alias row may list comma- or pipe-separated synonyms; matching
    is case-insensitive, whole-word, and also catches an optional
    "-shaped"/"shaped" suffix. Returns [] when the feature is disabled.
    """
    cfg = load_settings()
    if not cfg.get("shape_aliases_enabled", True):
        return []
    rules = []
    for entry in cfg.get("shape_aliases", []):
        raw = (entry.get("shape") or "").strip()
        name = (entry.get("name") or "").strip()
        if not (raw and name):
            continue
        # allow comma or pipe separated synonyms in one cell
        tokens = [tok.strip() for tok in re.split(r"[|,]", raw) if tok.strip()]
        if not tokens:
            continue
        # de-dup and prefer longer phrases first (prevents "diamond" eating "diamond emblem")
        tokens = sorted(set(tokens), key=len, reverse=True)
        alternation = "|".join(re.escape(tok) for tok in tokens)
        # word boundaries at ends; allow optional "-shaped" suffix
        pattern = rf"\b(?:{alternation})(?:-?shaped)?\b"
        rules.append((re.compile(pattern, flags=re.I), name))
    return rules
# Module-level cache of compiled alias rules; rebuilt on save via
# _refresh_shape_aliases_cache().
_SHAPE_ALIASES = _compile_shape_aliases_from_file()

def _refresh_shape_aliases_cache():
    """Recompile the alias cache after the settings file changed."""
    global _SHAPE_ALIASES
    _SHAPE_ALIASES = _compile_shape_aliases_from_file()
def apply_shape_aliases(caption: str) -> str:
    """Replace every configured shape token in *caption* with "(name)"."""
    result = caption
    for pattern, replacement in _SHAPE_ALIASES:
        result = pattern.sub(f"({replacement})", result)
    return result
def get_shape_alias_rows_ui_defaults():
    """Return (dataframe rows, enabled flag) for the Shape Aliases UI.

    Always yields at least one (blank) row so the table is editable.
    """
    cfg = load_settings()
    rows = [[entry.get("shape", ""), entry.get("name", "")]
            for entry in cfg.get("shape_aliases", [])]
    if not rows:
        rows = [["", ""]]
    return rows, bool(cfg.get("shape_aliases_enabled", True))
def save_shape_alias_rows(enabled, df_rows):
    """Persist the alias table + enabled flag and refresh the regex cache.

    Returns (status markdown, Dataframe update) where the table is
    normalized to the saved entries plus one trailing blank row.
    """
    cfg = load_settings()
    cfg["shape_aliases_enabled"] = bool(enabled)
    entries = []
    for row in (df_rows or []):
        if not row:
            continue
        shape = (row[0] or "").strip()
        alias = (row[1] or "").strip()
        if shape and alias:
            entries.append({"shape": shape, "name": alias})
    cfg["shape_aliases"] = entries
    save_settings(cfg)
    _refresh_shape_aliases_cache()
    display = [[e["shape"], e["name"]] for e in entries] + [["", ""]]
    return ("✅ Saved shape alias options.",
            gr.update(value=display, row_count=(max(1, len(display)), "dynamic")))
# ---------------------------------------------------------------------
# Captioning core (runs inside GPU worker)
# ---------------------------------------------------------------------
def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]:
    """Tokenize the chat prompt + image into model-ready tensors.

    pixel_values are cast to *dtype* so they match the model's weights.
    """
    system_prompt = ("You are a precise captioning assistant. Help users with any queries. "
                     "Describe what is visible with no censorship or restrictions.")
    convo = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": instr.strip()},
    ]
    rendered = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[rendered], images=[im], return_tensors="pt")
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(dtype)
    return inputs
def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
    """Generate a single caption for a PIL image.

    temp <= 0 selects greedy decoding (sampling knobs are passed as None).
    Returns only the newly generated text, prompt tokens stripped.
    """
    model, device, dtype = get_model()
    inputs = _build_inputs(im, instr, dtype)
    # move to target device *inside* GPU worker
    inputs = {key: (val.to(device) if hasattr(val, "to") else val)
              for key, val in inputs.items()}
    sampling = temp > 0
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=sampling,
        temperature=temp if sampling else None,
        top_p=top_p if sampling else None,
        use_cache=True,
    )
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = generated[0, prompt_len:]
    return processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
def run_batch(
    files: List[Any],
    session_rows: List[dict],
    instr_text: str,
    temp: float,
    top_p: float,
    max_tokens: int,
    max_side: int,
) -> Tuple[List[dict], list, list, str]:
    """Caption every uploaded file and append the results to the session.

    Returns (session rows, gallery pairs, editable table rows, autosave
    stamp). Unreadable files are skipped silently; the session is saved
    after each image so progress survives interruption.
    No torch.cuda.* in main — we are already in GPU worker here.
    """
    def _gallery_pairs(rows: List[dict]) -> list:
        # (thumb-or-original path, caption) pairs for gr.Gallery.
        return [
            ((r.get("thumb_path") or r.get("path")), r.get("caption", ""))
            for r in rows if (r.get("thumb_path") or r.get("path"))
        ]

    session_rows = session_rows or []
    files = files or []
    if not files:
        return (session_rows, _gallery_pairs(session_rows),
                _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}")
    # Fix: settings were re-read from disk for every image; they are
    # invariant for the batch, so read them once up front.
    s = load_settings()
    trigger = s.get("trigger", "")
    begin = s.get("begin", "")
    end = s.get("end", "")
    for f in files:
        # Gradio may hand back plain paths or file-like wrapper objects.
        path = f if isinstance(f, str) else getattr(f, "name", None) or getattr(f, "path", None)
        if not path or not os.path.exists(path):
            continue
        try:
            im = Image.open(path).convert("RGB")
        except Exception:
            continue  # skip non-image / corrupt uploads
        im = resize_for_model(im, max_side)
        cap = caption_once(im, instr_text, temp, top_p, max_tokens)
        cap = apply_shape_aliases(cap)
        cap = apply_prefix_suffix(cap, trigger, begin, end)
        session_rows.append({
            "filename": os.path.basename(path),
            "caption": cap,
            "path": path,
            "thumb_path": ensure_thumb(path, 256),
        })
        save_session(session_rows)  # autosave after every image
    return (session_rows, _gallery_pairs(session_rows),
            _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}")
def caption_single(img: Image.Image, instr: str) -> str:
    """Caption one image from the Single tab using persisted settings."""
    if img is None:
        return "No image provided."
    cfg = load_settings()
    resized = resize_for_model(img, int(cfg.get("max_side", 896)))
    caption = caption_once(
        resized,
        instr,
        cfg.get("temperature", 0.6),
        cfg.get("top_p", 0.9),
        cfg.get("max_tokens", 256),
    )
    caption = apply_shape_aliases(caption)
    return apply_prefix_suffix(caption, cfg.get("trigger", ""), cfg.get("begin", ""), cfg.get("end", ""))
# tiny warmup so Spaces sees a GPU function at startup
def _gpu_startup_warm():
    """Best-effort warmup: caption a tiny gray image to pre-load the model."""
    try:
        warm_img = Image.new("RGB", (64, 64), (127, 127, 127))
        caption_once(warm_img, "Warm up.", temp=0.0, top_p=1.0, max_tokens=8)
    except Exception as e:
        # Warmup is optional — never let it crash app startup.
        print("[ForgeCaptions] GPU warmup skipped:", e)
    else:
        print("[ForgeCaptions] GPU warmup complete")
| # --------------------------------------------------------------------- | |
| # Export helpers | |
| # --------------------------------------------------------------------- | |
| def _rows_to_table(rows: List[dict]) -> list: | |
| return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])] | |
| def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]: | |
| tbl = table_value or [] | |
| new = [] | |
| for i, r in enumerate(rows or []): | |
| r = dict(r) | |
| if i < len(tbl) and len(tbl[i]) >= 2: | |
| r["filename"] = str(tbl[i][0]) if tbl[i][0] is not None else r.get("filename","") | |
| r["caption"] = str(tbl[i][1]) if tbl[i][1] is not None else r.get("caption","") | |
| new.append(r) | |
| return new | |
def export_csv_from_table(table_value: Any) -> str:
    """Write the editable table to a timestamped CSV in /tmp; return its path."""
    rows = table_value or []
    out_path = f"/tmp/forgecaptions_{int(time.time())}.csv"
    with open(out_path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["filename", "caption"])
        for row in rows:
            writer.writerow(row)
    return out_path
def _resize_for_excel(path: str, px: int) -> str:
    """Produce a ≤px JPEG copy of *path* in EXCEL_THUMB_DIR for embedding.

    Falls back to the original path when the image cannot be read or the
    copy cannot be written.
    """
    try:
        img = Image.open(path).convert("RGB")
    except Exception:
        return path
    width, height = img.size
    longest = max(width, height)
    if longest > px:
        scale = px / longest
        img = img.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
    stem = os.path.splitext(os.path.basename(path))[0]
    out_path = os.path.join(EXCEL_THUMB_DIR, f"{stem}_{px}px.jpg")
    try:
        img.save(out_path, "JPEG", quality=85, optimize=True)
    except Exception:
        return path
    return out_path
def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_px: int) -> str:
    """Export session rows to a /tmp .xlsx with one embedded thumbnail per row.

    Captions edited in the on-screen table take precedence over the stored
    session caption (matched by filename). Missing/broken images are
    skipped without aborting the export.

    Raises:
        RuntimeError: when openpyxl is not installed.
    Returns:
        Path of the written workbook.
    """
    try:
        from openpyxl import Workbook
        from openpyxl.drawing.image import Image as XLImage
    except Exception as e:
        raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
    # Map filename → latest edited caption from the UI table.
    caption_by_file: Dict[str, str] = {}
    for row in (table_value or []):
        if not row:
            continue
        # Fix: a None filename cell previously str()-ed into the literal
        # "None" and polluted the lookup map; treat None as empty.
        fn = str(row[0]) if len(row) > 0 and row[0] is not None else ""
        cap = str(row[1]) if len(row) > 1 and row[1] is not None else ""
        if fn:
            caption_by_file[fn] = cap
    wb = Workbook()
    ws = wb.active
    ws.title = "ForgeCaptions"
    ws.append(["image", "filename", "caption"])
    ws.column_dimensions["A"].width = 24
    ws.column_dimensions["B"].width = 42
    ws.column_dimensions["C"].width = 100
    row_h = int(int(thumb_px) * 0.75)  # px→pt-ish
    r_i = 2
    for r in (session_rows or []):
        fn = r.get("filename", "")
        cap = caption_by_file.get(fn, r.get("caption", ""))
        ws.cell(row=r_i, column=2, value=fn)
        ws.cell(row=r_i, column=3, value=cap)
        img_path = r.get("thumb_path") or r.get("path")
        if img_path and os.path.exists(img_path):
            try:
                resized = _resize_for_excel(img_path, int(thumb_px))
                ws.add_image(XLImage(resized), f"A{r_i}")
                ws.row_dimensions[r_i].height = row_h
            except Exception:
                pass  # best-effort: a bad image must not abort the export
        r_i += 1
    out = f"/tmp/forgecaptions_{int(time.time())}.xlsx"
    wb.save(out)
    return out
def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]:
    """Apply table edits to the session, persist, and refresh the gallery.

    Returns (rows, gallery pairs, autosave stamp).
    """
    merged = _table_to_rows(table_value, session_rows or [])
    save_session(merged)
    gallery_pairs = []
    for row in merged:
        img = row.get("thumb_path") or row.get("path")
        if img:
            gallery_pairs.append((img, row.get("caption", "")))
    return merged, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"
# ---------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------
# Global stylesheet: hero header layout, logo sizing, and the two
# scrollable panes (#cfGal / #cfTable) whose scrolling is synced by the
# injected <script> further down.
BASE_CSS = """
:root{--galleryW:50%;--tableW:50%;}
.gradio-container{max-width:100%!important}
.cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
margin:4px 0 12px; text-align:center;}
.cf-hero > div { text-align:center; }
.cf-logo{height:calc(3.25rem + 3 * 1.1rem + 18px);width:auto;object-fit:contain}
.cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
.cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
.cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
#cfGal .grid > div { height: 96px; }
"""
with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
    # ensure Spaces sees a GPU function at start (without touching CUDA in main)
    demo.load(_gpu_startup_warm, inputs=None, outputs=None)

    # Persisted settings seed the initial widget values below; invalid
    # style names are dropped up front.
    settings = load_settings()
    settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (long)"]

    # Hero header (logo + title + feature bullets).
    gr.HTML(value=f"""
<div class="cf-hero">
{logo_b64_img()}
<div>
<h1 class="cf-title">ForgeCaptions</h1>
<div class="cf-sub">Batch captioning</div>
<div class="cf-sub">Scrollable editor & autosave</div>
<div class="cf-sub">CSV / Excel export</div>
</div>
</div>
<hr>""")

    # ── Controls
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("Caption style (choose one or combine)", open=True):
                    style_checks = gr.CheckboxGroup(
                        choices=STYLE_OPTIONS,
                        value=settings.get("styles", ["Character training (long)"]),
                        label=None
                    )
                with gr.Accordion("Extra options", open=False):
                    extra_opts = gr.CheckboxGroup(
                        choices=[NAME_OPTION] + EXTRA_CHOICES,
                        value=settings.get("extras", []),
                        label=None
                    )
                with gr.Accordion("Name & Prefix/Suffix", open=False):
                    name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
                    trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
                    add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
                    add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
            with gr.Column(scale=1):
                with gr.Accordion("Model Instructions", open=False):
                    # Read-only preview of the composed model instruction;
                    # also passed as input to the caption handlers below.
                    instruction_preview = gr.Textbox(label=None, lines=12)
                dataset_name = gr.Textbox(label="Dataset name (export title prefix)",
                                          value=settings.get("dataset_name", "forgecaptions"))
                max_side = gr.Slider(256, 1024, settings.get("max_side", 896), step=32, label="Max side (resize)")
                excel_thumb_px = gr.Slider(64, 256, value=settings.get("excel_thumb_px", 128),
                                           step=8, label="Excel thumbnail size (px)")
                gr.Markdown("Generation settings: temperature 0.6 • top-p 0.9 • max tokens 256")

    def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
        # Rebuild the instruction preview AND persist all control values.
        instr = final_instruction(styles or ["Character training (long)"], extra or [], name_value)
        cfg = load_settings()
        cfg.update({
            "styles": styles or ["Character training (long)"],
            "extras": extra or [],
            "name": name_value,
            "trigger": trigv, "begin": begv, "end": endv,
            "excel_thumb_px": int(excel_px),
            "max_side": int(ms),
        })
        save_settings(cfg)
        return instr

    # Any control change refreshes the preview and autosaves settings.
    for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side]:
        comp.change(_refresh_instruction,
                    inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side],
                    outputs=[instruction_preview])
    # Initial preview on page load.
    demo.load(lambda s,e,n: final_instruction(s or ["Character training (long)"], e or [], n),
              inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])

    # ── Shape Aliases (improved)
    with gr.Accordion("Shape Aliases", open=False):
        gr.Markdown(
            "### 🔷 Shape Aliases\n"
            "Replace literal **shape tokens** in captions with a preferred **name**.\n\n"
            "**How to use:**\n"
            "- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `penis, cock | phallic`\n"
            "- Right column = replacement name, e.g. `family-emblem`\n\n"
            "Matches are case-insensitive, use whole words, and also catch `*-shaped` (e.g., `diamond-shaped`).\n"
            "Multi-word phrases are supported."
        )
        init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
        enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled)
        alias_table = gr.Dataframe(
            headers=["shape (literal token)", "name to insert"],
            value=init_rows,
            col_count=(2, "fixed"),
            row_count=(max(1, len(init_rows)), "dynamic"),
            datatype=["str","str"],
            type="array",
            interactive=True
        )
        with gr.Row():
            add_row_btn = gr.Button("+ Add row", variant="secondary")
            clear_btn = gr.Button("Clear", variant="secondary")
            save_btn = gr.Button("💾 Save", variant="primary")
        save_status = gr.Markdown("")

        def _add_row(cur):
            # Append one blank row to the alias table.
            cur = (cur or []) + [["", ""]]
            return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))

        def _clear_rows():
            # Reset the alias table to a single blank row.
            return gr.update(value=[["", ""]], row_count=(1, "dynamic"))

        add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
        clear_btn.click(_clear_rows, outputs=[alias_table])
        save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table], outputs=[save_status, alias_table])

    # ── Tabs: Single & Batch
    with gr.Tabs():
        with gr.Tab("Single"):
            input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512)
            single_caption_btn = gr.Button("Caption")
            single_caption_out = gr.Textbox(label="Caption (single)")
            single_caption_btn.click(
                caption_single,
                inputs=[input_image_single, instruction_preview],
                outputs=[single_caption_out]
            )
        with gr.Tab("Batch"):
            with gr.Accordion("Uploaded images", open=True):
                input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
            run_button = gr.Button("Caption batch", variant="primary")

    # ── Results + Table (same position)
    # rows_state holds the session row dicts; seeded from the autosaved
    # /tmp session so a page reload restores prior work.
    rows_state = gr.State(load_session())
    autosave_md = gr.Markdown("Ready.")
    with gr.Row():
        with gr.Column(scale=1):
            gallery = gr.Gallery(
                label="Results (image + caption)",
                show_label=True,
                columns=3,
                height=520,
                elem_id="cfGal",
                elem_classes=["cf-scroll"]
            )
        with gr.Column(scale=1):
            table = gr.Dataframe(
                label="Editable captions (whole session)",
                value=_rows_to_table(load_session()),
                headers=["filename", "caption"],
                interactive=True,
                wrap=True,
                elem_id="cfTable",
                elem_classes=["cf-scroll"]
            )

    # Exports
    with gr.Row():
        with gr.Column():
            export_csv_btn = gr.Button("Export CSV")
            csv_file = gr.File(label="CSV file", visible=False)
        with gr.Column():
            export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails")
            xlsx_file = gr.File(label="Excel file", visible=False)

    def _initial_gallery(rows):
        # Populate the gallery from the restored session on page load.
        rows = rows or []
        return [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
                for r in rows if (r.get("thumb_path") or r.get("path"))]
    demo.load(_initial_gallery, inputs=[rows_state], outputs=[gallery])

    # Scroll sync: client-side JS that equalizes pane heights and mirrors
    # scrollTop between the gallery and the table.
    gr.HTML("""
<script>
(function () {
function findGalleryScrollRoot() {
const host = document.querySelector("#cfGal");
if (!host) return null;
return host.querySelector(".grid") || host.querySelector("[data-testid='gallery']") || host;
}
function findTableScrollRoot() {
const host = document.querySelector("#cfTable");
if (!host) return null;
return host.querySelector(".wrap") ||
host.querySelector(".dataframe-wrap") ||
(host.querySelector("table") ? host.querySelector("table").parentElement : null) ||
host;
}
function syncScroll(a, b) {
if (!a || !b) return;
let lock = false;
const onScrollA = () => { if (lock) return; lock = true; b.scrollTop = a.scrollTop; lock = false; };
const onScrollB = () => { if (lock) return; lock = true; a.scrollTop = b.scrollTop; lock = false; };
a.addEventListener("scroll", onScrollA, { passive: true });
b.addEventListener("scroll", onScrollB, { passive: true });
}
let tries = 0;
const timer = setInterval(() => {
tries++;
const gal = findGalleryScrollRoot();
const tab = findTableScrollRoot();
if (gal && tab) {
const H = Math.min(gal.clientHeight || 520, tab.clientHeight || 520);
gal.style.maxHeight = H + "px";
gal.style.overflowY = "auto";
tab.style.maxHeight = H + "px";
tab.style.overflowY = "auto";
syncScroll(gal, tab);
clearInterval(timer);
}
if (tries > 20) clearInterval(timer);
}, 100);
})();
</script>
""")

    # Batch run → rows + gallery + table
    def _run_click(files, rows, instr, ms):
        # Pull generation knobs from persisted settings, then delegate.
        s = load_settings()
        t = s.get("temperature", 0.6)
        p = s.get("top_p", 0.9)
        m = s.get("max_tokens", 256)
        new_rows, gal, tbl, stamp = run_batch(files, rows or [], instr, t, p, m, int(ms))
        return new_rows, gal, tbl, stamp
    run_button.click(
        _run_click,
        inputs=[input_files, rows_state, instruction_preview, max_side],
        outputs=[rows_state, gallery, table, autosave_md]
    )

    # Table edits sync
    table.change(
        sync_table_to_session,
        inputs=[table, rows_state],
        outputs=[rows_state, gallery, autosave_md]
    )

    # Exports: each button writes a file and un-hides its gr.File output
    # (the same component is listed twice: once for the value, once for
    # the visibility update).
    export_csv_btn.click(
        lambda tbl: (export_csv_from_table(tbl), gr.update(visible=True)),
        inputs=[table], outputs=[csv_file, csv_file]
    )
    export_xlsx_btn.click(
        lambda tbl, rows, px: (export_excel_with_thumbs(tbl, rows or [], int(px)), gr.update(visible=True)),
        inputs=[table, rows_state, excel_thumb_px], outputs=[xlsx_file, xlsx_file]
    )
# Launch (SSR off for stability on Spaces)
if __name__ == "__main__":
    demo.queue(max_size=64).launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),  # Spaces injects PORT
        ssr_mode=False,
        debug=True,
        show_error=True,
    )