griddev commited on
Commit
bba7394
Β·
verified Β·
1 Parent(s): 6455831

Deploy Streamlit Space app

Browse files
Files changed (1) hide show
  1. app.py +124 -64
app.py CHANGED
@@ -165,37 +165,54 @@ DEFAULT_SHAKESPEARE_WEIGHTS = "./shakespeare_transformer.pt"
165
  WEIGHTS_REPO_ID = os.getenv("WEIGHTS_REPO_ID", "griddev/vlm-caption-weights")
166
  WEIGHTS_CACHE_DIR = os.getenv("WEIGHTS_CACHE_DIR", "./weights_bundle")
167
 
 
 
 
 
 
 
 
 
 
168
 
169
- def _resolve_weight_paths():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  output_root = DEFAULT_OUTPUT_ROOT
171
  shakespeare_file = DEFAULT_SHAKESPEARE_FILE
172
  shakespeare_weights = DEFAULT_SHAKESPEARE_WEIGHTS
173
- local_ready = (
174
- os.path.isdir(output_root)
175
- and os.path.exists(shakespeare_file)
176
- and os.path.exists(shakespeare_weights)
177
  )
178
- if local_ready:
179
  return output_root, shakespeare_file, shakespeare_weights
180
 
181
  try:
182
- from huggingface_hub import snapshot_download
183
- snapshot_download(
184
- repo_id=WEIGHTS_REPO_ID,
185
- repo_type="model",
186
- local_dir=WEIGHTS_CACHE_DIR,
187
- local_dir_use_symlinks=False,
188
- allow_patterns=[
189
- "outputs/*",
190
- "outputs/**/*",
191
- "input.txt",
192
- "shakespeare_transformer.pt",
193
- ],
194
- )
195
- candidate_output_root = os.path.join(WEIGHTS_CACHE_DIR, "outputs")
196
- candidate_shakespeare_file = os.path.join(WEIGHTS_CACHE_DIR, "input.txt")
197
  candidate_shakespeare_weights = os.path.join(
198
- WEIGHTS_CACHE_DIR, "shakespeare_transformer.pt"
199
  )
200
  if os.path.isdir(candidate_output_root):
201
  output_root = candidate_output_root
@@ -209,9 +226,6 @@ def _resolve_weight_paths():
209
  return output_root, shakespeare_file, shakespeare_weights
210
 
211
 
212
- OUTPUT_ROOT, SHAKESPEARE_FILE, SHAKESPEARE_WEIGHTS_PATH = _resolve_weight_paths()
213
-
214
-
215
  # ─────────────────────────────────────────────────────────────────────────────
216
  # Device
217
  # ─────────────────────────────────────────────────────────────────────────────
@@ -228,12 +242,36 @@ def get_device():
228
 
229
  def _has_finetuned(model_dir, subdir):
230
  """Check if a fine-tuned checkpoint exists for a given model + subdir."""
231
- path = os.path.join(OUTPUT_ROOT, model_dir, subdir)
232
- return os.path.isdir(path) and len(os.listdir(path)) > 0
233
-
234
-
235
- def _ckpt_path(model_dir, subdir):
236
- return os.path.join(OUTPUT_ROOT, model_dir, subdir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
 
239
  # ─────────────────────────────────────────────────────────────────────────────
@@ -250,7 +288,10 @@ def load_blip(weight_source="base"):
250
  "Salesforce/blip-image-captioning-base")
251
 
252
  if weight_source != "base":
253
- ckpt = _ckpt_path("blip", weight_source)
 
 
 
254
  if os.path.isdir(ckpt) and os.listdir(ckpt):
255
  try:
256
  loaded = BlipForConditionalGeneration.from_pretrained(ckpt)
@@ -276,7 +317,10 @@ def load_vit_gpt2(weight_source="base"):
276
  model.config.pad_token_id = tokenizer.pad_token_id
277
 
278
  if weight_source != "base":
279
- ckpt = _ckpt_path("vit_gpt2", weight_source)
 
 
 
280
  if os.path.isdir(ckpt) and os.listdir(ckpt):
281
  try:
282
  loaded = VisionEncoderDecoderModel.from_pretrained(ckpt)
@@ -298,7 +342,10 @@ def load_git(weight_source="base"):
298
  model = AutoModelForCausalLM.from_pretrained(model_id)
299
 
300
  if weight_source != "base":
301
- ckpt = _ckpt_path("git", weight_source)
 
 
 
302
  if os.path.isdir(ckpt) and os.listdir(ckpt):
303
  try:
304
  loaded = AutoModelForCausalLM.from_pretrained(ckpt)
@@ -317,9 +364,12 @@ def load_custom_vlm(weight_source="base"):
317
  from config import CFG
318
  device = get_device()
319
  cfg = CFG()
320
- cfg.output_root = OUTPUT_ROOT
321
- cfg.shakespeare_file = SHAKESPEARE_FILE
322
- cfg.shakespeare_weights_path = SHAKESPEARE_WEIGHTS_PATH
 
 
 
323
 
324
  if not os.path.exists(cfg.shakespeare_file):
325
  return None, None, None, None, device
@@ -518,36 +568,27 @@ with st.sidebar:
518
  st.markdown("### πŸ”¬ VLM Caption Lab")
519
  st.markdown("---")
520
 
 
 
 
521
  # ── Weight Source ─────────────────────────────────────────────────────────
522
- weight_options = {
523
- "πŸ”΅ Base (Pretrained)": "base",
524
- "🟒 Fine-tuned (Best)": "best",
525
- "🟑 Fine-tuned (Latest)": "latest",
526
- }
 
 
527
  weight_choice = st.radio(
528
  "**Weight Source**", list(weight_options.keys()), index=0,
529
  help="Base = HuggingFace pretrained. Best/Latest = your fine-tuned checkpoints."
530
  )
531
  weight_source = weight_options[weight_choice]
532
-
533
- # Show availability indicators
534
- ft_status = []
535
- for mdl_dir, mdl_name in [("blip", "BLIP"), ("vit_gpt2", "ViT-GPT2"),
536
- ("git", "GIT"), ("custom_vlm", "Custom VLM")]:
537
- has_best = _has_finetuned(mdl_dir, "best")
538
- has_latest = _has_finetuned(mdl_dir, "latest")
539
- if has_best or has_latest:
540
- ft_status.append(f" βœ… {mdl_name}")
541
- else:
542
- ft_status.append(f" ⬜ {mdl_name}")
543
- if weight_source != "base":
544
- st.caption("Fine-tuned checkpoints:\n" + "\n".join(ft_status))
545
 
546
  st.markdown("---")
547
 
548
- # ── Architecture Selector ─────────────────────────────────────────────────
549
- selected_model = st.selectbox("**Architecture**", MODEL_KEYS, index=0)
550
-
551
  if selected_model in ("BLIP (Multimodal Mixture Attention)",
552
  "ViT-GPT2 (Standard Cross-Attention)"):
553
  mode_options = [
@@ -678,14 +719,21 @@ with tab_caption:
678
 
679
  with col_result:
680
  if uploaded_file and generate_btn:
681
- with st.spinner(f"Loading {MODEL_SHORT[selected_model]} ({weight_source}) + generating…"):
 
 
 
 
 
 
 
682
  try:
683
  caption = generate_caption(
684
  selected_model, selected_mode, image,
685
  num_beams=num_beams,
686
  max_new_tokens=max_new_tokens,
687
  length_penalty=length_penalty,
688
- weight_source=weight_source,
689
  )
690
  except Exception as e:
691
  st.error(f"Generation error: {e}")
@@ -693,7 +741,7 @@ with tab_caption:
693
 
694
  if caption:
695
  render_caption_card(
696
- selected_model, caption, weight_source,
697
  num_beams, length_penalty, max_new_tokens,
698
  container=st,
699
  )
@@ -745,7 +793,7 @@ with tab_compare:
745
  with col_ctrl:
746
  if compare_file:
747
  compare_image = Image.open(compare_file).convert("RGB")
748
- st.image(compare_image, caption="Comparison Image", width="stretch")
749
 
750
  compare_btn = st.button("πŸš€ Compare All 4 Models",
751
  disabled=(compare_file is None or not is_common_mode),
@@ -754,6 +802,18 @@ with tab_compare:
754
  if compare_file and compare_btn:
755
  compare_image = Image.open(compare_file).convert("RGB")
756
 
 
 
 
 
 
 
 
 
 
 
 
 
757
  # Generate captions from all 4 models
758
  results = {}
759
  progress = st.progress(0, text="Starting comparison...")
@@ -784,7 +844,7 @@ with tab_compare:
784
  num_beams=num_beams,
785
  max_new_tokens=max_new_tokens,
786
  length_penalty=length_penalty,
787
- weight_source=weight_source,
788
  )
789
  results[model_key] = cap
790
  except Exception as e:
@@ -804,7 +864,7 @@ with tab_compare:
804
  cap = results.get(model_key, "[Not available]")
805
  with col:
806
  render_caption_card(
807
- model_key, cap, weight_source,
808
  num_beams, length_penalty, max_new_tokens,
809
  container=st,
810
  card_class="compare-card",
 
165
  WEIGHTS_REPO_ID = os.getenv("WEIGHTS_REPO_ID", "griddev/vlm-caption-weights")
166
  WEIGHTS_CACHE_DIR = os.getenv("WEIGHTS_CACHE_DIR", "./weights_bundle")
167
 
168
+ MODEL_DIR = {
169
+ "BLIP (Multimodal Mixture Attention)": "blip",
170
+ "ViT-GPT2 (Standard Cross-Attention)": "vit_gpt2",
171
+ "GIT (Zero Cross-Attention)": "git",
172
+ "Custom VLM (Shakespeare Prefix)": "custom_vlm",
173
+ }
174
+
175
+
176
+ OUTPUT_ROOT = DEFAULT_OUTPUT_ROOT
177
 
178
+
179
+ @st.cache_resource(show_spinner=False)
180
+ def _download_weights(need_outputs: bool, need_shakespeare: bool) -> str:
181
+ from huggingface_hub import snapshot_download
182
+ allow_patterns = []
183
+ if need_outputs:
184
+ allow_patterns += ["outputs/*", "outputs/**/*"]
185
+ if need_shakespeare:
186
+ allow_patterns += ["input.txt", "shakespeare_transformer.pt"]
187
+ if not allow_patterns:
188
+ return WEIGHTS_CACHE_DIR
189
+ return snapshot_download(
190
+ repo_id=WEIGHTS_REPO_ID,
191
+ repo_type="model",
192
+ local_dir=WEIGHTS_CACHE_DIR,
193
+ local_dir_use_symlinks=False,
194
+ allow_patterns=allow_patterns,
195
+ )
196
+
197
+
198
+ def _resolve_weight_paths(need_outputs: bool, need_shakespeare: bool):
199
  output_root = DEFAULT_OUTPUT_ROOT
200
  shakespeare_file = DEFAULT_SHAKESPEARE_FILE
201
  shakespeare_weights = DEFAULT_SHAKESPEARE_WEIGHTS
202
+
203
+ have_outputs = os.path.isdir(output_root) and len(os.listdir(output_root)) > 0
204
+ have_shakespeare = (
205
+ os.path.exists(shakespeare_file) and os.path.exists(shakespeare_weights)
206
  )
207
+ if (not need_outputs or have_outputs) and (not need_shakespeare or have_shakespeare):
208
  return output_root, shakespeare_file, shakespeare_weights
209
 
210
  try:
211
+ cache_dir = _download_weights(need_outputs, need_shakespeare)
212
+ candidate_output_root = os.path.join(cache_dir, "outputs")
213
+ candidate_shakespeare_file = os.path.join(cache_dir, "input.txt")
 
 
 
 
 
 
 
 
 
 
 
 
214
  candidate_shakespeare_weights = os.path.join(
215
+ cache_dir, "shakespeare_transformer.pt"
216
  )
217
  if os.path.isdir(candidate_output_root):
218
  output_root = candidate_output_root
 
226
  return output_root, shakespeare_file, shakespeare_weights
227
 
228
 
 
 
 
229
  # ─────────────────────────────────────────────────────────────────────────────
230
  # Device
231
  # ─────────────────────────────────────────────────────────────────────────────
 
242
 
243
  def _has_finetuned(model_dir, subdir):
244
  """Check if a fine-tuned checkpoint exists for a given model + subdir."""
245
+ candidates = [
246
+ os.path.join(DEFAULT_OUTPUT_ROOT, model_dir, subdir),
247
+ os.path.join(WEIGHTS_CACHE_DIR, "outputs", model_dir, subdir),
248
+ ]
249
+ for path in candidates:
250
+ if os.path.isdir(path) and len(os.listdir(path)) > 0:
251
+ return True
252
+ return False
253
+
254
+
255
+ def _ckpt_path(output_root, model_dir, subdir):
256
+ return os.path.join(output_root, model_dir, subdir)
257
+
258
+
259
+ def _resolve_weight_source_for_model(model_name, requested_source):
260
+ if requested_source == "base":
261
+ return requested_source, None
262
+ model_dir = MODEL_DIR.get(model_name)
263
+ if not model_dir:
264
+ return requested_source, None
265
+ if _has_finetuned(model_dir, requested_source):
266
+ return requested_source, None
267
+ _resolve_weight_paths(
268
+ need_outputs=True,
269
+ need_shakespeare=(model_dir == "custom_vlm"),
270
+ )
271
+ if _has_finetuned(model_dir, requested_source):
272
+ return requested_source, None
273
+ short_name = MODEL_SHORT.get(model_name, model_name)
274
+ return "base", f"{short_name} has no '{requested_source}' weights. Using base."
275
 
276
 
277
  # ─────────────────────────────────────────────────────────────────────────────
 
288
  "Salesforce/blip-image-captioning-base")
289
 
290
  if weight_source != "base":
291
+ output_root, _, _ = _resolve_weight_paths(
292
+ need_outputs=True, need_shakespeare=False
293
+ )
294
+ ckpt = _ckpt_path(output_root, "blip", weight_source)
295
  if os.path.isdir(ckpt) and os.listdir(ckpt):
296
  try:
297
  loaded = BlipForConditionalGeneration.from_pretrained(ckpt)
 
317
  model.config.pad_token_id = tokenizer.pad_token_id
318
 
319
  if weight_source != "base":
320
+ output_root, _, _ = _resolve_weight_paths(
321
+ need_outputs=True, need_shakespeare=False
322
+ )
323
+ ckpt = _ckpt_path(output_root, "vit_gpt2", weight_source)
324
  if os.path.isdir(ckpt) and os.listdir(ckpt):
325
  try:
326
  loaded = VisionEncoderDecoderModel.from_pretrained(ckpt)
 
342
  model = AutoModelForCausalLM.from_pretrained(model_id)
343
 
344
  if weight_source != "base":
345
+ output_root, _, _ = _resolve_weight_paths(
346
+ need_outputs=True, need_shakespeare=False
347
+ )
348
+ ckpt = _ckpt_path(output_root, "git", weight_source)
349
  if os.path.isdir(ckpt) and os.listdir(ckpt):
350
  try:
351
  loaded = AutoModelForCausalLM.from_pretrained(ckpt)
 
364
  from config import CFG
365
  device = get_device()
366
  cfg = CFG()
367
+ output_root, shakespeare_file, shakespeare_weights = _resolve_weight_paths(
368
+ need_outputs=(weight_source != "base"), need_shakespeare=True
369
+ )
370
+ cfg.output_root = output_root
371
+ cfg.shakespeare_file = shakespeare_file
372
+ cfg.shakespeare_weights_path = shakespeare_weights
373
 
374
  if not os.path.exists(cfg.shakespeare_file):
375
  return None, None, None, None, device
 
568
  st.markdown("### πŸ”¬ VLM Caption Lab")
569
  st.markdown("---")
570
 
571
+ # ── Architecture Selector ─────────────────────────────────────────────────
572
+ selected_model = st.selectbox("**Architecture**", MODEL_KEYS, index=0)
573
+
574
  # ── Weight Source ─────────────────────────────────────────────────────────
575
+ model_dir = MODEL_DIR.get(selected_model)
576
+ weight_options = {"πŸ”΅ Base (Pretrained)": "base"}
577
+ if model_dir and _has_finetuned(model_dir, "best"):
578
+ weight_options["🟒 Fine-tuned (Best)"] = "best"
579
+ if model_dir and _has_finetuned(model_dir, "latest"):
580
+ weight_options["🟑 Fine-tuned (Latest)"] = "latest"
581
+
582
  weight_choice = st.radio(
583
  "**Weight Source**", list(weight_options.keys()), index=0,
584
  help="Base = HuggingFace pretrained. Best/Latest = your fine-tuned checkpoints."
585
  )
586
  weight_source = weight_options[weight_choice]
587
+ if len(weight_options) == 1:
588
+ st.caption("Fine-tuned weights not available for this model.")
 
 
 
 
 
 
 
 
 
 
 
589
 
590
  st.markdown("---")
591
 
 
 
 
592
  if selected_model in ("BLIP (Multimodal Mixture Attention)",
593
  "ViT-GPT2 (Standard Cross-Attention)"):
594
  mode_options = [
 
719
 
720
  with col_result:
721
  if uploaded_file and generate_btn:
722
+ resolved_source, warning_msg = _resolve_weight_source_for_model(
723
+ selected_model, weight_source
724
+ )
725
+ if warning_msg:
726
+ st.warning(warning_msg)
727
+ with st.spinner(
728
+ f"Loading {MODEL_SHORT[selected_model]} ({resolved_source}) + generating…"
729
+ ):
730
  try:
731
  caption = generate_caption(
732
  selected_model, selected_mode, image,
733
  num_beams=num_beams,
734
  max_new_tokens=max_new_tokens,
735
  length_penalty=length_penalty,
736
+ weight_source=resolved_source,
737
  )
738
  except Exception as e:
739
  st.error(f"Generation error: {e}")
 
741
 
742
  if caption:
743
  render_caption_card(
744
+ selected_model, caption, resolved_source,
745
  num_beams, length_penalty, max_new_tokens,
746
  container=st,
747
  )
 
793
  with col_ctrl:
794
  if compare_file:
795
  compare_image = Image.open(compare_file).convert("RGB")
796
+ st.image(compare_image, caption="Comparison Image", use_column_width=True)
797
 
798
  compare_btn = st.button("πŸš€ Compare All 4 Models",
799
  disabled=(compare_file is None or not is_common_mode),
 
802
  if compare_file and compare_btn:
803
  compare_image = Image.open(compare_file).convert("RGB")
804
 
805
+ resolved_sources = {}
806
+ warnings = []
807
+ for model_key in MODEL_KEYS:
808
+ resolved, warning_msg = _resolve_weight_source_for_model(
809
+ model_key, weight_source
810
+ )
811
+ resolved_sources[model_key] = resolved
812
+ if warning_msg:
813
+ warnings.append(warning_msg)
814
+ for msg in sorted(set(warnings)):
815
+ st.warning(msg)
816
+
817
  # Generate captions from all 4 models
818
  results = {}
819
  progress = st.progress(0, text="Starting comparison...")
 
844
  num_beams=num_beams,
845
  max_new_tokens=max_new_tokens,
846
  length_penalty=length_penalty,
847
+ weight_source=resolved_sources.get(model_key, weight_source),
848
  )
849
  results[model_key] = cap
850
  except Exception as e:
 
864
  cap = results.get(model_key, "[Not available]")
865
  with col:
866
  render_caption_card(
867
+ model_key, cap, resolved_sources.get(model_key, weight_source),
868
  num_beams, length_penalty, max_new_tokens,
869
  container=st,
870
  card_class="compare-card",