JS6969 commited on
Commit
5cd2a27
·
verified ·
1 Parent(s): 6120301

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +335 -247
app.py CHANGED
@@ -1,9 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os, io, csv, time, json, base64, re
2
  from typing import List, Tuple, Dict, Any
3
 
4
- # ---------------------------------------------------------------------
5
- # Caching
6
- # ---------------------------------------------------------------------
7
  os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
8
  os.makedirs(os.environ["HF_HOME"], exist_ok=True)
9
 
@@ -12,37 +27,52 @@ from PIL import Image
12
  import torch
13
  from transformers import LlavaForConditionalGeneration, AutoProcessor
14
 
15
- # ── HF Spaces GPU decorator (no-op on CPU/local) ─────────────────────
16
  try:
17
  import spaces
18
  gpu = spaces.GPU()
19
- except Exception: # local/CPU
20
- def gpu(f): return f
21
 
 
 
 
 
22
  APP_DIR = os.getcwd()
23
- SESSION_FILE = "/tmp/session.json"
24
- SETTINGS_FILE = "/tmp/cf_settings.json"
25
- JOURNAL_FILE = "/tmp/cf_journal.json"
26
  THUMB_CACHE = os.path.expanduser("~/.cache/forgecaptions/thumbs")
27
  EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
28
  os.makedirs(THUMB_CACHE, exist_ok=True)
29
  os.makedirs(EXCEL_THUMB_DIR, exist_ok=True)
30
 
31
- # ---------------------------------------------------------------------
32
- # Model identifiers
33
- # ---------------------------------------------------------------------
34
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
35
 
36
- # Load the processor on CPU (safe in stateless env)
37
- processor = AutoProcessor.from_pretrained(MODEL_PATH)
 
 
 
 
 
38
 
39
- # Lazy GPU/CPU model (created inside GPU worker only)
 
 
 
 
 
 
40
  _MODEL = None
41
  _DEVICE = "cpu"
42
  _DTYPE = torch.float32
43
 
44
  def get_model():
45
- """Create/reuse model; only call this from inside @gpu functions."""
 
 
 
46
  global _MODEL, _DEVICE, _DTYPE
47
  if _MODEL is None:
48
  if torch.cuda.is_available():
@@ -52,7 +82,7 @@ def get_model():
52
  MODEL_PATH,
53
  torch_dtype=_DTYPE,
54
  low_cpu_mem_usage=True,
55
- device_map=0, # GPU:0 (inside GPU worker process)
56
  )
57
  else:
58
  _DEVICE = "cpu"
@@ -67,11 +97,10 @@ def get_model():
67
  print(f"[ForgeCaptions] Model ready on {_DEVICE} dtype={_DTYPE}")
68
  return _MODEL, _DEVICE, _DTYPE
69
 
70
- print(f"[ForgeCaptions] Gradio version: {gr.__version__}")
71
 
72
- # ---------------------------------------------------------------------
73
- # Instruction templates & options
74
- # ---------------------------------------------------------------------
75
  STYLE_OPTIONS = [
76
  "Descriptive (short)", "Descriptive (long)",
77
  "Character training (short)", "Character training (long)",
@@ -84,10 +113,9 @@ STYLE_OPTIONS = [
84
  "Aesthetic tags (comma-sep)"
85
  ]
86
 
87
- CAPTION_TYPE_MAP = {
88
  "Descriptive (short)": "One sentence (≤25 words) describing the most important visible elements only. No speculation.",
89
  "Descriptive (long)": "Write a detailed description for this image.",
90
-
91
  "Character training (short)": (
92
  "Output a concise, prompt-like caption for character LoRA/ID training. "
93
  "Include visible character name {name} if provided, distinct physical traits, clothing, pose, camera/cinematic cues. "
@@ -98,24 +126,17 @@ CAPTION_TYPE_MAP = {
98
  "Use {name} if provided; describe only what is visible: physique, face/hair, clothing, accessories, actions, pose, "
99
  "camera angle/focal cues, lighting, background context. 1–3 sentences; no backstory or meta."
100
  ),
101
-
102
  "Flux_D (short)": "Output a short Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
103
  "Flux_D (long)": "Output a long Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
104
-
105
  "Aesthetic tags (comma-sep)": "Return only comma-separated aesthetic tags capturing subject, medium, style, lighting, composition. No sentences.",
106
-
107
  "E-commerce product (short)": "One sentence highlighting key attributes, material, color, use case. No fluff.",
108
  "E-commerce product (long)": "Write a crisp product description highlighting key attributes, materials, color, usage, and distinguishing traits.",
109
-
110
  "Portrait (photography) (short)": "One sentence portrait description: subject, pose/expression, camera angle, lighting, background.",
111
  "Portrait (photography) (long)": "Describe a portrait: subject, age range, pose, facial expression, camera angle, focal length cues, lighting, background.",
112
-
113
  "Landscape (photography) (short)": "One sentence landscape description: major elements, time of day, weather, vantage point, mood.",
114
  "Landscape (photography) (long)": "Describe landscape elements, time of day, weather, vantage point, composition, and mood.",
115
-
116
  "Art analysis (no artist names) (short)": "One sentence describing medium, style, composition, palette; do not mention artist/title.",
117
  "Art analysis (no artist names) (long)": "Analyze the artwork's visible elements, medium, style, composition, palette. Do not mention artist names or titles.",
118
-
119
  "Social caption (short)": "Write a short, catchy caption (max 25 words) describing the visible content. No hashtags.",
120
  "Social caption (long)": "Write a slightly longer, engaging caption (≤50 words) describing the visible content. No hashtags."
121
  }
@@ -142,73 +163,10 @@ EXTRA_CHOICES = [
142
  ]
143
  NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
144
 
145
- # ---------------------------------------------------------------------
146
- # Helpers
147
- # ---------------------------------------------------------------------
148
- def ensure_thumb(path: str, max_side=256) -> str:
149
- try:
150
- im = Image.open(path).convert("RGB")
151
- except Exception:
152
- return path
153
- w, h = im.size
154
- if max(w, h) > max_side:
155
- s = max_side / max(w, h)
156
- im = im.resize((int(w*s), int(h*s)), Image.LANCZOS)
157
- base = os.path.basename(path)
158
- out_path = os.path.join(THUMB_CACHE, os.path.splitext(base)[0] + f"_thumb_{max_side}.jpg")
159
- try:
160
- im.save(out_path, "JPEG", quality=85, optimize=True)
161
- return out_path
162
- except Exception:
163
- return path
164
-
165
- def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
166
- w, h = im.size
167
- if max(w, h) <= max_side:
168
- return im
169
- s = max_side / max(w, h)
170
- return im.resize((int(w*s), int(h*s)), Image.LANCZOS)
171
-
172
- def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_text: str) -> str:
173
- parts = []
174
- if trigger_word.strip():
175
- parts.append(trigger_word.strip())
176
- if begin_text.strip():
177
- parts.append(begin_text.strip())
178
- parts.append(caption.strip())
179
- if end_text.strip():
180
- parts.append(end_text.strip())
181
- return " ".join([p for p in parts if p])
182
-
183
- # Instruction + caption
184
- def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str) -> str:
185
- styles = style_list or ["Descriptive (short)"]
186
- parts = [CAPTION_TYPE_MAP.get(s, "") for s in styles]
187
- core = " ".join(p for p in parts if p).strip()
188
- if extra_opts:
189
- core += " " + " ".join(extra_opts)
190
- if NAME_OPTION in (extra_opts or []):
191
- core = core.replace("{name}", (name_value or "{NAME}").strip())
192
- return core
193
-
194
- def logo_b64_img() -> str:
195
- candidates = [
196
- os.path.join(APP_DIR, "forgecaptions-logo.png"),
197
- os.path.join(APP_DIR, "captionforge-logo.png"),
198
- "/home/user/app/forgecaptions-logo.png",
199
- "forgecaptions-logo.png",
200
- "captionforge-logo.png",
201
- ]
202
- for p in candidates:
203
- if os.path.exists(p):
204
- with open(p, "rb") as f:
205
- b64 = base64.b64encode(f.read()).decode("ascii")
206
- return f"<img src='data:image/png;base64,{b64}' alt='ForgeCaptions' class='cf-logo'>"
207
- return ""
208
 
209
- # ---------------------------------------------------------------------
210
- # Persistence
211
- # ---------------------------------------------------------------------
212
  def save_session(rows: List[dict]):
213
  with open(SESSION_FILE, "w", encoding="utf-8") as f:
214
  json.dump(rows, f, ensure_ascii=False, indent=2)
@@ -229,13 +187,14 @@ def load_settings() -> dict:
229
  cfg = json.load(f)
230
  else:
231
  cfg = {}
 
232
  defaults = {
233
  "dataset_name": "forgecaptions",
234
  "temperature": 0.6,
235
  "top_p": 0.9,
236
  "max_tokens": 256,
237
  "max_side": 896,
238
- "styles": ["Character training (long)"], # default changed
239
  "extras": [],
240
  "name": "",
241
  "trigger": "",
@@ -247,23 +206,9 @@ def load_settings() -> dict:
247
  }
248
  for k, v in defaults.items():
249
  cfg.setdefault(k, v)
250
-
251
- legacy_map = {
252
- "Descriptive": "Descriptive (short)",
253
- "LoRA (Flux_D Realism)": "LoRA (Flux_D Realism) (short)",
254
- "Portrait (photography)": "Portrait (photography) (short)",
255
- "Landscape (photography)": "Landscape (photography) (short)",
256
- "Art analysis (no artist names)": "Art analysis (no artist names) (short)",
257
- "E-commerce product": "E-commerce product (short)",
258
- }
259
  styles = cfg.get("styles") or []
260
- migrated = []
261
- for s in styles if isinstance(styles, list) else [styles]:
262
- migrated.append(legacy_map.get(s, s))
263
- migrated = [s for s in migrated if s in STYLE_OPTIONS]
264
- if not migrated:
265
- migrated = ["Descriptive (short)"]
266
- cfg["styles"] = migrated
267
  return cfg
268
 
269
  def save_journal(data: dict):
@@ -276,10 +221,72 @@ def load_journal() -> dict:
276
  return json.load(f)
277
  return {}
278
 
279
- # ---------------------------------------------------------------------
280
- # Shape Aliases
281
- # ---------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  def _compile_shape_aliases_from_file():
 
 
 
 
 
283
  s = load_settings()
284
  if not s.get("shape_aliases_enabled", True):
285
  return []
@@ -289,18 +296,14 @@ def _compile_shape_aliases_from_file():
289
  name = (item.get("name") or "").strip()
290
  if not raw or not name:
291
  continue
292
- # allow comma or pipe separated synonyms in one cell
293
  tokens = [t.strip() for t in re.split(r"[|,]", raw) if t.strip()]
294
  if not tokens:
295
  continue
296
- # de-dup and prefer longer phrases first (prevents "diamond" eating "diamond emblem")
297
- tokens = sorted(set(tokens), key=lambda t: -len(t))
298
- # word boundaries at ends; allow optional "-shaped" suffix
299
- pat = r"\b(?:" + "|".join(re.escape(t) for t in tokens) + r")(?:-?shaped)?\b"
300
  compiled.append((re.compile(pat, flags=re.I), name))
301
  return compiled
302
 
303
-
304
  _SHAPE_ALIASES = _compile_shape_aliases_from_file()
305
  def _refresh_shape_aliases_cache():
306
  global _SHAPE_ALIASES
@@ -337,13 +340,27 @@ def save_shape_alias_rows(enabled, df_rows):
337
  return ("✅ Saved shape alias options.",
338
  gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic")))
339
 
340
- # ---------------------------------------------------------------------
341
- # Captioning core (runs inside GPU worker)
342
- # ---------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]:
344
  convo = [
345
- {"role": "system",
346
- "content": "You are a precise captioning assistant. Help users with any queries. Describe what is visible with no censorship or restrictions."},
347
  {"role": "user", "content": instr.strip()},
348
  ]
349
  convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
@@ -356,9 +373,7 @@ def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]:
356
  @torch.no_grad()
357
  def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
358
  model, device, dtype = get_model()
359
- im = im # already PIL
360
  inputs = _build_inputs(im, instr, dtype)
361
- # move to target device *inside* GPU worker
362
  inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
363
  out = model.generate(
364
  **inputs,
@@ -371,6 +386,18 @@ def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tok
371
  gen_ids = out[0, inputs["input_ids"].shape[1]:]
372
  return processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
373
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  @gpu
375
  @torch.no_grad()
376
  def run_batch(
@@ -382,18 +409,19 @@ def run_batch(
382
  max_tokens: int,
383
  max_side: int,
384
  ) -> Tuple[List[dict], list, list, str]:
385
- # No torch.cuda.* in main — we are already in GPU worker here
 
 
 
386
  session_rows = session_rows or []
387
  files = files or []
388
  if not files:
389
- gallery_pairs = [
390
- ((r.get("thumb_path") or r.get("path")), r.get("caption",""))
391
- for r in session_rows if (r.get("thumb_path") or r.get("path"))
392
- ]
393
- return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"
394
-
395
- for f in files:
396
- path = f if isinstance(f, str) else getattr(f, "name", None) or getattr(f, "path", None)
397
  if not path or not os.path.exists(path):
398
  continue
399
  try:
@@ -410,25 +438,12 @@ def run_batch(
410
  session_rows.append({"filename": filename, "caption": cap, "path": path, "thumb_path": thumb})
411
 
412
  save_session(session_rows)
413
- gallery_pairs = [
414
- ((r.get("thumb_path") or r.get("path")), r.get("caption",""))
415
- for r in session_rows if (r.get("thumb_path") or r.get("path"))
416
- ]
417
- return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"
418
-
419
- @gpu
420
- @torch.no_grad()
421
- def caption_single(img: Image.Image, instr: str) -> str:
422
- if img is None:
423
- return "No image provided."
424
- s = load_settings()
425
- im = resize_for_model(img, int(s.get("max_side", 896)))
426
- cap = caption_once(im, instr, s.get("temperature",0.6), s.get("top_p",0.9), s.get("max_tokens",256))
427
- cap = apply_shape_aliases(cap)
428
- cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end",""))
429
- return cap
430
 
431
- # tiny warmup so Spaces sees a GPU function at startup
432
  @gpu
433
  @torch.no_grad()
434
  def _gpu_startup_warm():
@@ -439,9 +454,10 @@ def _gpu_startup_warm():
439
  except Exception as e:
440
  print("[ForgeCaptions] GPU warmup skipped:", e)
441
 
442
- # ---------------------------------------------------------------------
443
- # Export helpers
444
- # ---------------------------------------------------------------------
 
445
  def _rows_to_table(rows: List[dict]) -> list:
446
  return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
447
 
@@ -460,9 +476,7 @@ def export_csv_from_table(table_value: Any) -> str:
460
  data = table_value or []
461
  out = f"/tmp/forgecaptions_{int(time.time())}.csv"
462
  with open(out, "w", newline="", encoding="utf-8") as f:
463
- w = csv.writer(f)
464
- w.writerow(["filename", "caption"])
465
- w.writerows(data)
466
  return out
467
 
468
  def _resize_for_excel(path: str, px: int) -> str:
@@ -504,11 +518,11 @@ def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_p
504
  ws.column_dimensions["B"].width = 42
505
  ws.column_dimensions["C"].width = 100
506
 
507
- row_h = int(int(thumb_px) * 0.75) # px→pt-ish
 
508
  r_i = 2
509
  for r in (session_rows or []):
510
- fn = r.get("filename","")
511
- cap = caption_by_file.get(fn, r.get("caption",""))
512
  ws.cell(row=r_i, column=2, value=fn)
513
  ws.cell(row=r_i, column=3, value=cap)
514
  img_path = r.get("thumb_path") or r.get("path")
@@ -526,53 +540,63 @@ def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_p
526
  wb.save(out)
527
  return out
528
 
529
- def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]:
530
- session_rows = _table_to_rows(table_value, session_rows or [])
531
- save_session(session_rows)
532
- gallery_pairs = [
533
- ((r.get("thumb_path") or r.get("path")), r.get("caption",""))
534
- for r in session_rows if (r.get("thumb_path") or r.get("path"))
535
- ]
536
- return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"
537
 
538
- # ---------------------------------------------------------------------
539
- # UI
540
- # ---------------------------------------------------------------------
541
  BASE_CSS = """
542
  :root{--galleryW:50%;--tableW:50%;}
543
  .gradio-container{max-width:100%!important}
544
- .cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
545
- margin:4px 0 12px; text-align:center;}
546
- .cf-hero > div { text-align:center; }
547
- .cf-logo{height:calc(3.25rem + 3 * 1.1rem + 18px);width:auto;object-fit:contain}
 
 
 
 
 
 
548
  .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
549
  .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
 
 
550
  .cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
551
  #cfGal .grid > div { height: 96px; }
552
  """
553
 
554
  with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
555
- # ensure Spaces sees a GPU function at start (without touching CUDA in main)
556
  demo.load(_gpu_startup_warm, inputs=None, outputs=None)
557
 
558
- settings = load_settings()
559
- settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (long)"]
560
-
561
  gr.HTML(value=f"""
562
  <div class="cf-hero">
563
  {logo_b64_img()}
564
- <div>
565
  <h1 class="cf-title">ForgeCaptions</h1>
566
  <div class="cf-sub">Batch captioning</div>
567
  <div class="cf-sub">Scrollable editor & autosave</div>
568
  <div class="cf-sub">CSV / Excel export</div>
569
  </div>
570
  </div>
571
- <hr>""")
 
 
 
 
 
 
 
 
 
 
 
572
 
573
- # ── Controls
574
  with gr.Group():
575
  with gr.Row():
 
576
  with gr.Column(scale=2):
577
  with gr.Accordion("Caption style (choose one or combine)", open=True):
578
  style_checks = gr.CheckboxGroup(
@@ -592,6 +616,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
592
  add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
593
  add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
594
 
 
595
  with gr.Column(scale=1):
596
  with gr.Accordion("Model Instructions", open=False):
597
  instruction_preview = gr.Textbox(label=None, lines=12)
@@ -600,8 +625,15 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
600
  max_side = gr.Slider(256, 1024, settings.get("max_side", 896), step=32, label="Max side (resize)")
601
  excel_thumb_px = gr.Slider(64, 256, value=settings.get("excel_thumb_px", 128),
602
  step=8, label="Excel thumbnail size (px)")
603
- gr.Markdown("Generation settings: temperature 0.6 • top-p 0.9 • max tokens 256")
604
-
 
 
 
 
 
 
 
605
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
606
  instr = final_instruction(styles or ["Character training (long)"], extra or [], name_value)
607
  cfg = load_settings()
@@ -624,43 +656,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
624
  demo.load(lambda s,e,n: final_instruction(s or ["Character training (long)"], e or [], n),
625
  inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
626
 
627
- # ── Shape Aliases (improved)
628
- with gr.Accordion("Shape Aliases", open=False):
629
- gr.Markdown(
630
- "### 🔷 Shape Aliases\n"
631
- "Replace literal **shape tokens** in captions with a preferred **name**.\n\n"
632
- "**How to use:**\n"
633
- "- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `penis, cock | phallic`\n"
634
- "- Right column = replacement name, e.g. `family-emblem`\n\n"
635
- "Matches are case-insensitive, use whole words, and also catch `*-shaped` (e.g., `diamond-shaped`).\n"
636
- "Multi-word phrases are supported."
637
- )
638
- init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
639
- enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled)
640
- alias_table = gr.Dataframe(
641
- headers=["shape (literal token)", "name to insert"],
642
- value=init_rows,
643
- col_count=(2, "fixed"),
644
- row_count=(max(1, len(init_rows)), "dynamic"),
645
- datatype=["str","str"],
646
- type="array",
647
- interactive=True
648
- )
649
- with gr.Row():
650
- add_row_btn = gr.Button("+ Add row", variant="secondary")
651
- clear_btn = gr.Button("Clear", variant="secondary")
652
- save_btn = gr.Button("💾 Save", variant="primary")
653
- save_status = gr.Markdown("")
654
- def _add_row(cur):
655
- cur = (cur or []) + [["", ""]]
656
- return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
657
- def _clear_rows():
658
- return gr.update(value=[["", ""]], row_count=(1, "dynamic"))
659
- add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
660
- clear_btn.click(_clear_rows, outputs=[alias_table])
661
- save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table], outputs=[save_status, alias_table])
662
-
663
- # ── Tabs: Single & Batch
664
  with gr.Tabs():
665
  with gr.Tab("Single"):
666
  input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512)
@@ -677,9 +673,10 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
677
  input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
678
  run_button = gr.Button("Caption batch", variant="primary")
679
 
680
- # ── Results + Table (same position)
681
  rows_state = gr.State(load_session())
682
  autosave_md = gr.Markdown("Ready.")
 
683
 
684
  with gr.Row():
685
  with gr.Column(scale=1):
@@ -702,7 +699,14 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
702
  elem_classes=["cf-scroll"]
703
  )
704
 
705
- # Exports
 
 
 
 
 
 
 
706
  with gr.Row():
707
  with gr.Column():
708
  export_csv_btn = gr.Button("Export CSV")
@@ -711,13 +715,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
711
  export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails")
712
  xlsx_file = gr.File(label="Excel file", visible=False)
713
 
714
- def _initial_gallery(rows):
715
- rows = rows or []
716
- return [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
717
- for r in rows if (r.get("thumb_path") or r.get("path"))]
718
- demo.load(_initial_gallery, inputs=[rows_state], outputs=[gallery])
719
-
720
- # Scroll sync
721
  gr.HTML("""
722
  <script>
723
  (function () {
@@ -762,29 +760,116 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
762
  </script>
763
  """)
764
 
765
- # Batch run rows + gallery + table
766
- def _run_click(files, rows, instr, ms):
 
 
 
 
 
767
  s = load_settings()
768
- t = s.get("temperature", 0.6)
769
- p = s.get("top_p", 0.9)
770
- m = s.get("max_tokens", 256)
771
- new_rows, gal, tbl, stamp = run_batch(files, rows or [], instr, t, p, m, int(ms))
772
- return new_rows, gal, tbl, stamp
 
 
 
 
 
 
 
 
 
 
 
 
 
773
 
774
  run_button.click(
775
  _run_click,
776
- inputs=[input_files, rows_state, instruction_preview, max_side],
777
- outputs=[rows_state, gallery, table, autosave_md]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
778
  )
779
 
780
- # Table edits sync
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
  table.change(
782
  sync_table_to_session,
783
  inputs=[table, rows_state],
784
  outputs=[rows_state, gallery, autosave_md]
785
  )
786
 
787
- # Exports
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
788
  export_csv_btn.click(
789
  lambda tbl: (export_csv_from_table(tbl), gr.update(visible=True)),
790
  inputs=[table], outputs=[csv_file, csv_file]
@@ -794,7 +879,10 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
794
  inputs=[table, rows_state, excel_thumb_px], outputs=[xlsx_file, xlsx_file]
795
  )
796
 
797
- # Launch (SSR off for stability on Spaces)
 
 
 
798
  if __name__ == "__main__":
799
  demo.queue(max_size=64).launch(
800
  server_name="0.0.0.0",
 
1
+ # =====================================================================
2
+ # ForgeCaptions - Gradio app for single & batch image captioning
3
+ # =====================================================================
4
+ # CHANGELOG (this version)
5
+ # - GPU-safe: all CUDA only inside @spaces.GPU functions.
6
+ # - Restored: Single tab + Batch chunking (Auto / All-at-once / Manual step).
7
+ # - Shape Aliases: supports comma/pipe-separated synonyms per row.
8
+ # - Default caption style: "Character training (long)".
9
+ # - Model Instructions + Caption Style in minimizable accordions.
10
+ # - Excel export: thumbnail size slider controls image scaling & row height.
11
+ # - Header logo scaled to the full text stack (centered).
12
+ # - Kept gallery & table positions unchanged; scroll sync retained.
13
+ # =====================================================================
14
+
15
+ # ------------------------------
16
+ # 0) Imports & environment
17
+ # ------------------------------
18
  import os, io, csv, time, json, base64, re
19
  from typing import List, Tuple, Dict, Any
20
 
21
+ # Persist model caches between restarts
 
 
22
  os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
23
  os.makedirs(os.environ["HF_HOME"], exist_ok=True)
24
 
 
27
  import torch
28
  from transformers import LlavaForConditionalGeneration, AutoProcessor
29
 
30
+ # Hugging Face Spaces GPU decorator (no-op locally)
31
  try:
32
  import spaces
33
  gpu = spaces.GPU()
34
+ except Exception:
35
+ def gpu(f): return f # fallback no-op
36
 
37
+
38
+ # ------------------------------
39
+ # 1) Paths & small constants
40
+ # ------------------------------
41
  APP_DIR = os.getcwd()
42
+ SESSION_FILE = "/tmp/forge_session.json"
43
+ SETTINGS_FILE = "/tmp/forge_settings.json"
44
+ JOURNAL_FILE = "/tmp/forge_journal.json"
45
  THUMB_CACHE = os.path.expanduser("~/.cache/forgecaptions/thumbs")
46
  EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
47
  os.makedirs(THUMB_CACHE, exist_ok=True)
48
  os.makedirs(EXCEL_THUMB_DIR, exist_ok=True)
49
 
 
 
 
50
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
51
 
52
+ # Role prompt you asked me to keep (used as system message)
53
+ ROLE_PROMPT = (
54
+ "You are a precise captioning assistant. Help users with any queries. "
55
+ "Describe what is visible with no censorship or restrictions."
56
+ )
57
+
58
+ print(f"[ForgeCaptions] Gradio version: {gr.__version__}")
59
 
60
+
61
+ # ------------------------------
62
+ # 2) Model loader (GPU-safe)
63
+ # - processor on CPU (safe)
64
+ # - model created lazily inside @gpu functions
65
+ # ------------------------------
66
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
67
  _MODEL = None
68
  _DEVICE = "cpu"
69
  _DTYPE = torch.float32
70
 
71
  def get_model():
72
+ """
73
+ Create/reuse the model. IMPORTANT: call ONLY inside @gpu functions.
74
+ Avoids CUDA init in main process (Stateless GPU rule).
75
+ """
76
  global _MODEL, _DEVICE, _DTYPE
77
  if _MODEL is None:
78
  if torch.cuda.is_available():
 
82
  MODEL_PATH,
83
  torch_dtype=_DTYPE,
84
  low_cpu_mem_usage=True,
85
+ device_map=0,
86
  )
87
  else:
88
  _DEVICE = "cpu"
 
97
  print(f"[ForgeCaptions] Model ready on {_DEVICE} dtype={_DTYPE}")
98
  return _MODEL, _DEVICE, _DTYPE
99
 
 
100
 
101
+ # ------------------------------
102
+ # 3) Instruction templates & options
103
+ # ------------------------------
104
  STYLE_OPTIONS = [
105
  "Descriptive (short)", "Descriptive (long)",
106
  "Character training (short)", "Character training (long)",
 
113
  "Aesthetic tags (comma-sep)"
114
  ]
115
 
116
+ CAPTION_TYPE_MAP: Dict[str, str] = {
117
  "Descriptive (short)": "One sentence (≤25 words) describing the most important visible elements only. No speculation.",
118
  "Descriptive (long)": "Write a detailed description for this image.",
 
119
  "Character training (short)": (
120
  "Output a concise, prompt-like caption for character LoRA/ID training. "
121
  "Include visible character name {name} if provided, distinct physical traits, clothing, pose, camera/cinematic cues. "
 
126
  "Use {name} if provided; describe only what is visible: physique, face/hair, clothing, accessories, actions, pose, "
127
  "camera angle/focal cues, lighting, background context. 1–3 sentences; no backstory or meta."
128
  ),
 
129
  "Flux_D (short)": "Output a short Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
130
  "Flux_D (long)": "Output a long Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
 
131
  "Aesthetic tags (comma-sep)": "Return only comma-separated aesthetic tags capturing subject, medium, style, lighting, composition. No sentences.",
 
132
  "E-commerce product (short)": "One sentence highlighting key attributes, material, color, use case. No fluff.",
133
  "E-commerce product (long)": "Write a crisp product description highlighting key attributes, materials, color, usage, and distinguishing traits.",
 
134
  "Portrait (photography) (short)": "One sentence portrait description: subject, pose/expression, camera angle, lighting, background.",
135
  "Portrait (photography) (long)": "Describe a portrait: subject, age range, pose, facial expression, camera angle, focal length cues, lighting, background.",
 
136
  "Landscape (photography) (short)": "One sentence landscape description: major elements, time of day, weather, vantage point, mood.",
137
  "Landscape (photography) (long)": "Describe landscape elements, time of day, weather, vantage point, composition, and mood.",
 
138
  "Art analysis (no artist names) (short)": "One sentence describing medium, style, composition, palette; do not mention artist/title.",
139
  "Art analysis (no artist names) (long)": "Analyze the artwork's visible elements, medium, style, composition, palette. Do not mention artist names or titles.",
 
140
  "Social caption (short)": "Write a short, catchy caption (max 25 words) describing the visible content. No hashtags.",
141
  "Social caption (long)": "Write a slightly longer, engaging caption (≤50 words) describing the visible content. No hashtags."
142
  }
 
163
  ]
164
  NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ # ------------------------------
168
+ # 4) Persistence helpers (settings/session/journal)
169
+ # ------------------------------
170
  def save_session(rows: List[dict]):
171
  with open(SESSION_FILE, "w", encoding="utf-8") as f:
172
  json.dump(rows, f, ensure_ascii=False, indent=2)
 
187
  cfg = json.load(f)
188
  else:
189
  cfg = {}
190
+ # sensible defaults for this app/version
191
  defaults = {
192
  "dataset_name": "forgecaptions",
193
  "temperature": 0.6,
194
  "top_p": 0.9,
195
  "max_tokens": 256,
196
  "max_side": 896,
197
+ "styles": ["Character training (long)"], # default you requested
198
  "extras": [],
199
  "name": "",
200
  "trigger": "",
 
206
  }
207
  for k, v in defaults.items():
208
  cfg.setdefault(k, v)
209
+ # validate styles against allowed set
 
 
 
 
 
 
 
 
210
  styles = cfg.get("styles") or []
211
+ cfg["styles"] = [s for s in (styles if isinstance(styles, list) else [styles]) if s in STYLE_OPTIONS] or ["Character training (long)"]
 
 
 
 
 
 
212
  return cfg
213
 
214
  def save_journal(data: dict):
 
221
  return json.load(f)
222
  return {}
223
 
224
+
225
+ # ------------------------------
226
+ # 5) Small utilities (thumbs, resize, prefix/suffix)
227
+ # ------------------------------
228
+ def ensure_thumb(path: str, max_side=256) -> str:
229
+ try:
230
+ im = Image.open(path).convert("RGB")
231
+ except Exception:
232
+ return path
233
+ w, h = im.size
234
+ if max(w, h) > max_side:
235
+ s = max_side / max(w, h)
236
+ im = im.resize((int(w*s), int(h*s)), Image.LANCZOS)
237
+ base = os.path.basename(path)
238
+ out_path = os.path.join(THUMB_CACHE, os.path.splitext(base)[0] + f"_thumb_{max_side}.jpg")
239
+ try:
240
+ im.save(out_path, "JPEG", quality=85, optimize=True)
241
+ return out_path
242
+ except Exception:
243
+ return path
244
+
245
+ def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
246
+ w, h = im.size
247
+ if max(w, h) <= max_side:
248
+ return im
249
+ s = max_side / max(w, h)
250
+ return im.resize((int(w*s), int(h*s)), Image.LANCZOS)
251
+
252
+ def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_text: str) -> str:
253
+ parts = []
254
+ if trigger_word.strip():
255
+ parts.append(trigger_word.strip())
256
+ if begin_text.strip():
257
+ parts.append(begin_text.strip())
258
+ parts.append(caption.strip())
259
+ if end_text.strip():
260
+ parts.append(end_text.strip())
261
+ return " ".join([p for p in parts if p])
262
+
263
+ def logo_b64_img() -> str:
264
+ """
265
+ Load a PNG logo if present (falls back gracefully).
266
+ """
267
+ candidates = [
268
+ os.path.join(APP_DIR, "forgecaptions-logo.png"),
269
+ os.path.join(APP_DIR, "captionforge-logo.png"),
270
+ "forgecaptions-logo.png",
271
+ "captionforge-logo.png",
272
+ ]
273
+ for p in candidates:
274
+ if os.path.exists(p):
275
+ with open(p, "rb") as f:
276
+ b64 = base64.b64encode(f.read()).decode("ascii")
277
+ return f"<img src='data:image/png;base64,{b64}' alt='ForgeCaptions' class='cf-logo'>"
278
+ return ""
279
+
280
+
281
+ # ------------------------------
282
+ # 6) Shape Aliases (comma/pipe synonyms per row)
283
+ # ------------------------------
284
  def _compile_shape_aliases_from_file():
285
+ """
286
+ Build regex list from settings["shape_aliases"].
287
+ Left cell accepts comma OR pipe separated synonyms (multi-word OK).
288
+ Matches are case-insensitive, whole-word, and allow '-shaped' or ' shaped'.
289
+ """
290
  s = load_settings()
291
  if not s.get("shape_aliases_enabled", True):
292
  return []
 
296
  name = (item.get("name") or "").strip()
297
  if not raw or not name:
298
  continue
 
299
  tokens = [t.strip() for t in re.split(r"[|,]", raw) if t.strip()]
300
  if not tokens:
301
  continue
302
+ tokens = sorted(set(tokens), key=lambda t: -len(t)) # longest first
303
+ pat = r"\b(?:" + "|".join(re.escape(t) for t in tokens) + r")(?:[-\s]?shaped)?\b"
 
 
304
  compiled.append((re.compile(pat, flags=re.I), name))
305
  return compiled
306
 
 
307
  _SHAPE_ALIASES = _compile_shape_aliases_from_file()
308
  def _refresh_shape_aliases_cache():
309
  global _SHAPE_ALIASES
 
340
  return ("✅ Saved shape alias options.",
341
  gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic")))
342
 
343
+
344
+ # ------------------------------
345
+ # 7) Prompt builder (instruction text shown/used for model)
346
+ # ------------------------------
347
+ def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str) -> str:
348
+ styles = style_list or ["Character training (long)"]
349
+ parts = [CAPTION_TYPE_MAP.get(s, "") for s in styles]
350
+ core = " ".join(p for p in parts if p).strip()
351
+ if extra_opts:
352
+ core += " " + " ".join(extra_opts)
353
+ if NAME_OPTION in (extra_opts or []):
354
+ core = core.replace("{name}", (name_value or "{NAME}").strip())
355
+ return core
356
+
357
+
358
+ # ------------------------------
359
+ # 8) GPU caption functions
360
+ # ------------------------------
361
  def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]:
362
  convo = [
363
+ {"role": "system", "content": ROLE_PROMPT},
 
364
  {"role": "user", "content": instr.strip()},
365
  ]
366
  convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
 
373
  @torch.no_grad()
374
  def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
375
  model, device, dtype = get_model()
 
376
  inputs = _build_inputs(im, instr, dtype)
 
377
  inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
378
  out = model.generate(
379
  **inputs,
 
386
  gen_ids = out[0, inputs["input_ids"].shape[1]:]
387
  return processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
388
 
389
+ @gpu
390
+ @torch.no_grad()
391
+ def caption_single(img: Image.Image, instr: str) -> str:
392
+ if img is None:
393
+ return "No image provided."
394
+ s = load_settings()
395
+ im = resize_for_model(img, int(s.get("max_side", 896)))
396
+ cap = caption_once(im, instr, s.get("temperature",0.6), s.get("top_p",0.9), s.get("max_tokens",256))
397
+ cap = apply_shape_aliases(cap)
398
+ cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end",""))
399
+ return cap
400
+
401
  @gpu
402
  @torch.no_grad()
403
  def run_batch(
 
409
  max_tokens: int,
410
  max_side: int,
411
  ) -> Tuple[List[dict], list, list, str]:
412
+ """
413
+ Process a list of file paths and append results to session_rows.
414
+ Returns: updated rows, gallery_pairs, table_rows, status_text
415
+ """
416
  session_rows = session_rows or []
417
  files = files or []
418
  if not files:
419
+ gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
420
+ for r in session_rows if (r.get("thumb_path") or r.get("path"))]
421
+ table_rows = [[r.get("filename",""), r.get("caption","")] for r in session_rows]
422
+ return session_rows, gallery_pairs, table_rows, f"Saved • {time.strftime('%H:%M:%S')}"
423
+
424
+ for path in files:
 
 
425
  if not path or not os.path.exists(path):
426
  continue
427
  try:
 
438
  session_rows.append({"filename": filename, "caption": cap, "path": path, "thumb_path": thumb})
439
 
440
  save_session(session_rows)
441
+ gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
442
+ for r in session_rows if (r.get("thumb_path") or r.get("path"))]
443
+ table_rows = [[r.get("filename",""), r.get("caption","")] for r in session_rows]
444
+ return session_rows, gallery_pairs, table_rows, f"Saved • {time.strftime('%H:%M:%S')}"
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
+ # Ensure Spaces detects at least one GPU function at startup
447
  @gpu
448
  @torch.no_grad()
449
  def _gpu_startup_warm():
 
454
  except Exception as e:
455
  print("[ForgeCaptions] GPU warmup skipped:", e)
456
 
457
+
458
+ # ------------------------------
459
+ # 9) Export helpers (CSV/XLSX)
460
+ # ------------------------------
461
  def _rows_to_table(rows: List[dict]) -> list:
462
  return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
463
 
 
476
  data = table_value or []
477
  out = f"/tmp/forgecaptions_{int(time.time())}.csv"
478
  with open(out, "w", newline="", encoding="utf-8") as f:
479
+ w = csv.writer(f); w.writerow(["filename", "caption"]); w.writerows(data)
 
 
480
  return out
481
 
482
  def _resize_for_excel(path: str, px: int) -> str:
 
518
  ws.column_dimensions["B"].width = 42
519
  ws.column_dimensions["C"].width = 100
520
 
521
+ # Convert pixel target to approx. row points (Excel 0.75 * px)
522
+ row_h = int(int(thumb_px) * 0.75)
523
  r_i = 2
524
  for r in (session_rows or []):
525
+ fn = r.get("filename",""); cap = caption_by_file.get(fn, r.get("caption",""))
 
526
  ws.cell(row=r_i, column=2, value=fn)
527
  ws.cell(row=r_i, column=3, value=cap)
528
  img_path = r.get("thumb_path") or r.get("path")
 
540
  wb.save(out)
541
  return out
542
 
 
 
 
 
 
 
 
 
543
 
544
+ # ------------------------------
545
+ # 10) UI (Blocks)
546
+ # ------------------------------
547
  BASE_CSS = """
548
  :root{--galleryW:50%;--tableW:50%;}
549
  .gradio-container{max-width:100%!important}
550
+ .cf-hero{
551
+ display:flex; align-items:center; justify-content:center; gap:16px;
552
+ margin:4px 0 12px; text-align:center;
553
+ }
554
+ .cf-hero .cf-text{ text-align:center; }
555
+ .cf-logo{
556
+ /* Make logo fill roughly the full text stack; clamped for sanity */
557
+ height: clamp(120px, calc(3.25rem + 3 * 1.1rem + 24px), 180px);
558
+ width:auto; object-fit:contain; display:block; flex:0 0 auto;
559
+ }
560
  .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
561
  .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
562
+
563
+ /* Results area */
564
  .cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
565
  #cfGal .grid > div { height: 96px; }
566
  """
567
 
568
  with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
569
+ # Ensure Spaces sees a GPU function (without touching CUDA in main)
570
  demo.load(_gpu_startup_warm, inputs=None, outputs=None)
571
 
572
+ # ---- Header (logo + title center). Script sets logo height to match text exactly.
 
 
573
  gr.HTML(value=f"""
574
  <div class="cf-hero">
575
  {logo_b64_img()}
576
+ <div class="cf-text">
577
  <h1 class="cf-title">ForgeCaptions</h1>
578
  <div class="cf-sub">Batch captioning</div>
579
  <div class="cf-sub">Scrollable editor & autosave</div>
580
  <div class="cf-sub">CSV / Excel export</div>
581
  </div>
582
  </div>
583
+ <hr>
584
+ <script>
585
+ setTimeout(() => {{
586
+ const logo = document.querySelector(".cf-logo");
587
+ const text = document.querySelector(".cf-text");
588
+ if (logo && text) logo.style.height = text.getBoundingClientRect().height + "px";
589
+ }}, 0);
590
+ </script>
591
+ """)
592
+
593
+ # ---- Settings state (loaded once)
594
+ settings = load_settings()
595
 
596
+ # ---- Controls group (left/right columns)
597
  with gr.Group():
598
  with gr.Row():
599
+ # LEFT: Style + Extra + Name/Prefix/Suffix (accordions minimizable)
600
  with gr.Column(scale=2):
601
  with gr.Accordion("Caption style (choose one or combine)", open=True):
602
  style_checks = gr.CheckboxGroup(
 
616
  add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
617
  add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
618
 
619
+ # RIGHT: Instruction preview + dataset + sliders
620
  with gr.Column(scale=1):
621
  with gr.Accordion("Model Instructions", open=False):
622
  instruction_preview = gr.Textbox(label=None, lines=12)
 
625
  max_side = gr.Slider(256, 1024, settings.get("max_side", 896), step=32, label="Max side (resize)")
626
  excel_thumb_px = gr.Slider(64, 256, value=settings.get("excel_thumb_px", 128),
627
  step=8, label="Excel thumbnail size (px)")
628
+ # Chunking controls (restored)
629
+ chunk_mode = gr.Radio(
630
+ choices=["Auto", "Manual (all at once)", "Manual (step)"],
631
+ value="Manual (step)",
632
+ label="Batch mode"
633
+ )
634
+ chunk_size = gr.Slider(1, 50, value=10, step=1, label="Chunk size")
635
+
636
+ # -- Keep instruction text in sync with controls and persist to settings
637
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
638
  instr = final_instruction(styles or ["Character training (long)"], extra or [], name_value)
639
  cfg = load_settings()
 
656
  demo.load(lambda s,e,n: final_instruction(s or ["Character training (long)"], e or [], n),
657
  inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
658
 
659
+ # ---- Tabs: Single & Batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
660
  with gr.Tabs():
661
  with gr.Tab("Single"):
662
  input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512)
 
673
  input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
674
  run_button = gr.Button("Caption batch", variant="primary")
675
 
676
+ # ---- Results (UNCHANGED POSITION): Gallery left, Table right
677
  rows_state = gr.State(load_session())
678
  autosave_md = gr.Markdown("Ready.")
679
+ remaining_state = gr.State([]) # for manual step mode
680
 
681
  with gr.Row():
682
  with gr.Column(scale=1):
 
699
  elem_classes=["cf-scroll"]
700
  )
701
 
702
+ # ---- Step panel (restored)
703
+ step_panel = gr.Group(visible=False)
704
+ with step_panel:
705
+ step_msg = gr.Markdown("")
706
+ step_next = gr.Button("Process next chunk")
707
+ step_finish = gr.Button("Finish")
708
+
709
+ # ---- Exports
710
  with gr.Row():
711
  with gr.Column():
712
  export_csv_btn = gr.Button("Export CSV")
 
715
  export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails")
716
  xlsx_file = gr.File(label="Excel file", visible=False)
717
 
718
+ # ---- Scroll sync (gallery ↔ table)
 
 
 
 
 
 
719
  gr.HTML("""
720
  <script>
721
  (function () {
 
760
  </script>
761
  """)
762
 
763
+ # ---- Batch chunking logic (restored)
764
+ def _split_chunks(files, csize: int):
765
+ files = files or []
766
+ c = max(1, int(csize))
767
+ return [files[i:i+c] for i in range(0, len(files), c)]
768
+
769
+ def _tpms():
770
  s = load_settings()
771
+ return s.get("temperature", 0.6), s.get("top_p", 0.9), s.get("max_tokens", 256)
772
+
773
+ def _run_click(files, rows, instr, ms, mode, csize):
774
+ t, p, m = _tpms()
775
+ files = files or []
776
+ # Manual step → process first chunk only
777
+ if mode == "Manual (step)" and files:
778
+ chunks = _split_chunks(files, int(csize))
779
+ batch = chunks[0]
780
+ remaining = sum(chunks[1:], [])
781
+ new_rows, gal, tbl, stamp = run_batch(batch, rows or [], instr, t, p, m, int(ms))
782
+ panel_vis = gr.update(visible=bool(remaining))
783
+ msg = f"{len(remaining)} files remain. Process next chunk?"
784
+ return new_rows, gal, tbl, stamp, remaining, panel_vis, gr.update(value=msg)
785
+ # Auto / all-at-once → process everything in one go
786
+ else:
787
+ new_rows, gal, tbl, stamp = run_batch(files, rows or [], instr, t, p, m, int(ms))
788
+ return new_rows, gal, tbl, stamp, [], gr.update(visible=False), gr.update(value="")
789
 
790
  run_button.click(
791
  _run_click,
792
+ inputs=[input_files, rows_state, instruction_preview, max_side, chunk_mode, chunk_size],
793
+ outputs=[rows_state, gallery, table, autosave_md, remaining_state, step_panel, step_msg]
794
+ )
795
+
796
+ def _step_next(remain, rows, instr, ms, csize):
797
+ t, p, m = _tpms()
798
+ remain = remain or []
799
+ if not remain:
800
+ return rows, gr.update(value="No files remaining."), gr.update(visible=False), [], [], [], "Saved."
801
+ batch = remain[:int(csize)]
802
+ leftover = remain[int(csize):]
803
+ new_rows, gal, tbl, stamp = run_batch(batch, rows or [], instr, t, p, m, int(ms))
804
+ panel_vis = gr.update(visible=bool(leftover))
805
+ msg = f"{len(leftover)} files remain. Process next chunk?" if leftover else "All done."
806
+ return new_rows, msg, panel_vis, leftover, gal, tbl, stamp
807
+
808
+ step_next.click(
809
+ _step_next,
810
+ inputs=[remaining_state, rows_state, instruction_preview, max_side, chunk_size],
811
+ outputs=[rows_state, step_msg, step_panel, remaining_state, gallery, table, autosave_md]
812
  )
813
 
814
+ def _step_finish():
815
+ return gr.update(visible=False), gr.update(value=""), []
816
+
817
+ step_finish.click(
818
+ _step_finish,
819
+ inputs=None,
820
+ outputs=[step_panel, step_msg, remaining_state]
821
+ )
822
+
823
+ # ---- Table edits → persist + refresh gallery
824
+ def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]:
825
+ session_rows = _table_to_rows(table_value, session_rows or [])
826
+ save_session(session_rows)
827
+ gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
828
+ for r in session_rows if (r.get("thumb_path") or r.get("path"))]
829
+ return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"
830
+
831
  table.change(
832
  sync_table_to_session,
833
  inputs=[table, rows_state],
834
  outputs=[rows_state, gallery, autosave_md]
835
  )
836
 
837
+ # ---- Shape Aliases accordion (with examples & buttons)
838
+ with gr.Accordion("Shape Aliases", open=False):
839
+ gr.Markdown(
840
+ "### 🔷 Shape Aliases\n"
841
+ "Replace literal **shape tokens** in captions with a preferred **name**.\n\n"
842
+ "**How to use:**\n"
843
+ "- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `diamond, rhombus | lozenge`\n"
844
+ "- Right column = replacement name, e.g. `starkey-emblem`\n"
845
+ "Matches are case-insensitive, whole-word, and also catch `*-shaped` or `* shaped`."
846
+ )
847
+ init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
848
+ enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled)
849
+ alias_table = gr.Dataframe(
850
+ headers=["shape (token or synonyms)", "name to insert"],
851
+ value=init_rows,
852
+ col_count=(2, "fixed"),
853
+ row_count=(max(1, len(init_rows)), "dynamic"),
854
+ datatype=["str","str"],
855
+ type="array",
856
+ interactive=True
857
+ )
858
+ with gr.Row():
859
+ add_row_btn = gr.Button("+ Add row", variant="secondary")
860
+ clear_btn = gr.Button("Clear", variant="secondary")
861
+ save_btn = gr.Button("💾 Save", variant="primary")
862
+ save_status = gr.Markdown("")
863
+ def _add_row(cur):
864
+ cur = (cur or []) + [["", ""]]
865
+ return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
866
+ def _clear_rows():
867
+ return gr.update(value=[["", ""]], row_count=(1, "dynamic"))
868
+ add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
869
+ clear_btn.click(_clear_rows, outputs=[alias_table])
870
+ save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table], outputs=[save_status, alias_table])
871
+
872
+ # ---- Exports
873
  export_csv_btn.click(
874
  lambda tbl: (export_csv_from_table(tbl), gr.update(visible=True)),
875
  inputs=[table], outputs=[csv_file, csv_file]
 
879
  inputs=[table, rows_state, excel_thumb_px], outputs=[xlsx_file, xlsx_file]
880
  )
881
 
882
+
883
+ # ------------------------------
884
+ # 11) Launch (SSR disabled for stability on Spaces)
885
+ # ------------------------------
886
  if __name__ == "__main__":
887
  demo.queue(max_size=64).launch(
888
  server_name="0.0.0.0",