griddev commited on
Commit
7c69cda
Β·
verified Β·
1 Parent(s): 72e1b79

Deploy Streamlit Space app

Browse files
Files changed (1) hide show
  1. app.py +63 -39
app.py CHANGED
@@ -307,6 +307,18 @@ def _resolve_weight_source_for_model(model_name, requested_source):
307
  return "base", f"{short_name} has no '{requested_source}' weights. Using base."
308
 
309
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  # ─────────────────────────────────────────────────────────────────────────────
311
  # Cached Model Loaders (with weight_source support)
312
  # ─────────────────────────────────────────────────────────────────────────────
@@ -760,29 +772,30 @@ with tab_caption:
760
 
761
  with col_result:
762
  if uploaded_file and generate_btn:
763
- resolved_source, warning_msg = _resolve_weight_source_for_model(
764
- selected_model, weight_source
765
- )
766
- if warning_msg:
767
- st.warning(warning_msg)
768
- with st.spinner(
769
- f"Loading {MODEL_SHORT[selected_model]} ({resolved_source}) + generating…"
770
- ):
771
- try:
772
- caption = generate_caption(
773
- selected_model, selected_mode, image,
774
- num_beams=num_beams,
775
- max_new_tokens=max_new_tokens,
776
- length_penalty=length_penalty,
777
- weight_source=resolved_source,
778
- )
779
- except Exception as e:
780
- st.error(f"Generation error: {e}")
781
- caption = None
 
782
 
783
  if caption:
784
  render_caption_card(
785
- selected_model, caption, resolved_source,
786
  num_beams, length_penalty, max_new_tokens,
787
  container=st,
788
  )
@@ -844,16 +857,20 @@ with tab_compare:
844
  compare_image = Image.open(compare_file).convert("RGB")
845
 
846
  resolved_sources = {}
847
- warnings = []
848
  for model_key in MODEL_KEYS:
849
- resolved, warning_msg = _resolve_weight_source_for_model(
850
- model_key, weight_source
851
- )
852
- resolved_sources[model_key] = resolved
853
- if warning_msg:
854
- warnings.append(warning_msg)
855
- for msg in sorted(set(warnings)):
856
- st.warning(msg)
 
 
 
 
 
857
 
858
  # Generate captions from all 4 models
859
  results = {}
@@ -879,17 +896,24 @@ with tab_compare:
879
  else:
880
  mode = "Baseline (Full Attention)"
881
 
882
- try:
883
- cap = generate_caption(
884
- model_key, mode, compare_image,
885
- num_beams=num_beams,
886
- max_new_tokens=max_new_tokens,
887
- length_penalty=length_penalty,
888
- weight_source=resolved_sources.get(model_key, weight_source),
889
  )
890
- results[model_key] = cap
891
- except Exception as e:
892
- results[model_key] = f"[Error: {e}]"
 
 
 
 
 
 
 
 
 
893
 
894
  progress.progress(1.0, text="βœ… All models complete!")
895
 
 
307
  return "base", f"{short_name} has no '{requested_source}' weights. Using base."
308
 
309
 
310
+ def _finetuned_available_for_model(model_name, requested_source):
311
+ if requested_source == "base":
312
+ return True
313
+ model_dir = MODEL_DIR.get(model_name)
314
+ if not model_dir or model_dir in DISABLE_FINETUNE_FOR:
315
+ return False
316
+ if _has_finetuned(model_dir, requested_source):
317
+ return True
318
+ _ensure_model_outputs_available(model_dir)
319
+ return _has_finetuned(model_dir, requested_source)
320
+
321
+
322
  # ─────────────────────────────────────────────────────────────────────────────
323
  # Cached Model Loaders (with weight_source support)
324
  # ─────────────────────────────────────────────────────────────────────────────
 
772
 
773
  with col_result:
774
  if uploaded_file and generate_btn:
775
+ if not _finetuned_available_for_model(selected_model, weight_source):
776
+ st.error(
777
+ f"{MODEL_SHORT[selected_model]} does not have '{weight_source}' weights."
778
+ )
779
+ caption = None
780
+ else:
781
+ with st.spinner(
782
+ f"Loading {MODEL_SHORT[selected_model]} ({weight_source}) + generating…"
783
+ ):
784
+ try:
785
+ caption = generate_caption(
786
+ selected_model, selected_mode, image,
787
+ num_beams=num_beams,
788
+ max_new_tokens=max_new_tokens,
789
+ length_penalty=length_penalty,
790
+ weight_source=weight_source,
791
+ )
792
+ except Exception as e:
793
+ st.error(f"Generation error: {e}")
794
+ caption = None
795
 
796
  if caption:
797
  render_caption_card(
798
+ selected_model, caption, weight_source,
799
  num_beams, length_penalty, max_new_tokens,
800
  container=st,
801
  )
 
857
  compare_image = Image.open(compare_file).convert("RGB")
858
 
859
  resolved_sources = {}
 
860
  for model_key in MODEL_KEYS:
861
+ resolved_sources[model_key] = weight_source
862
+ if weight_source != "base":
863
+ missing = [
864
+ MODEL_SHORT[m]
865
+ for m in MODEL_KEYS
866
+ if not _finetuned_available_for_model(m, weight_source)
867
+ ]
868
+ if missing:
869
+ st.warning(
870
+ "Missing fine-tuned weights for: "
871
+ + ", ".join(missing)
872
+ + ". Marking those results as unavailable."
873
+ )
874
 
875
  # Generate captions from all 4 models
876
  results = {}
 
896
  else:
897
  mode = "Baseline (Full Attention)"
898
 
899
+ if not _finetuned_available_for_model(model_key, weight_source):
900
+ results[model_key] = (
901
+ f"[Fine-tuned '{weight_source}' weights not available]"
902
+ if weight_source != "base"
903
+ else "[Not available]"
 
 
904
  )
905
+ else:
906
+ try:
907
+ cap = generate_caption(
908
+ model_key, mode, compare_image,
909
+ num_beams=num_beams,
910
+ max_new_tokens=max_new_tokens,
911
+ length_penalty=length_penalty,
912
+ weight_source=weight_source,
913
+ )
914
+ results[model_key] = cap
915
+ except Exception as e:
916
+ results[model_key] = f"[Error: {e}]"
917
 
918
  progress.progress(1.0, text="βœ… All models complete!")
919