Spaces:
Running
Running
Deploy Streamlit Space app
Browse files
app.py
CHANGED
|
@@ -307,6 +307,18 @@ def _resolve_weight_source_for_model(model_name, requested_source):
|
|
| 307 |
return "base", f"{short_name} has no '{requested_source}' weights. Using base."
|
| 308 |
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 311 |
# Cached Model Loaders (with weight_source support)
|
| 312 |
# ─────────────────────────────────────────────────────────────────────────────
|
|
@@ -760,29 +772,30 @@ with tab_caption:
|
|
| 760 |
|
| 761 |
with col_result:
|
| 762 |
if uploaded_file and generate_btn:
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
|
|
|
| 782 |
|
| 783 |
if caption:
|
| 784 |
render_caption_card(
|
| 785 |
-
selected_model, caption,
|
| 786 |
num_beams, length_penalty, max_new_tokens,
|
| 787 |
container=st,
|
| 788 |
)
|
|
@@ -844,16 +857,20 @@ with tab_compare:
|
|
| 844 |
compare_image = Image.open(compare_file).convert("RGB")
|
| 845 |
|
| 846 |
resolved_sources = {}
|
| 847 |
-
warnings = []
|
| 848 |
for model_key in MODEL_KEYS:
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 857 |
|
| 858 |
# Generate captions from all 4 models
|
| 859 |
results = {}
|
|
@@ -879,17 +896,24 @@ with tab_compare:
|
|
| 879 |
else:
|
| 880 |
mode = "Baseline (Full Attention)"
|
| 881 |
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
length_penalty=length_penalty,
|
| 888 |
-
weight_source=resolved_sources.get(model_key, weight_source),
|
| 889 |
)
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 893 |
|
| 894 |
progress.progress(1.0, text="β
All models complete!")
|
| 895 |
|
|
|
|
| 307 |
return "base", f"{short_name} has no '{requested_source}' weights. Using base."
|
| 308 |
|
| 309 |
|
| 310 |
+
def _finetuned_available_for_model(model_name, requested_source):
|
| 311 |
+
if requested_source == "base":
|
| 312 |
+
return True
|
| 313 |
+
model_dir = MODEL_DIR.get(model_name)
|
| 314 |
+
if not model_dir or model_dir in DISABLE_FINETUNE_FOR:
|
| 315 |
+
return False
|
| 316 |
+
if _has_finetuned(model_dir, requested_source):
|
| 317 |
+
return True
|
| 318 |
+
_ensure_model_outputs_available(model_dir)
|
| 319 |
+
return _has_finetuned(model_dir, requested_source)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 323 |
# Cached Model Loaders (with weight_source support)
|
| 324 |
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
| 772 |
|
| 773 |
with col_result:
|
| 774 |
if uploaded_file and generate_btn:
|
| 775 |
+
if not _finetuned_available_for_model(selected_model, weight_source):
|
| 776 |
+
st.error(
|
| 777 |
+
f"{MODEL_SHORT[selected_model]} does not have '{weight_source}' weights."
|
| 778 |
+
)
|
| 779 |
+
caption = None
|
| 780 |
+
else:
|
| 781 |
+
with st.spinner(
|
| 782 |
+
f"Loading {MODEL_SHORT[selected_model]} ({weight_source}) + generatingβ¦"
|
| 783 |
+
):
|
| 784 |
+
try:
|
| 785 |
+
caption = generate_caption(
|
| 786 |
+
selected_model, selected_mode, image,
|
| 787 |
+
num_beams=num_beams,
|
| 788 |
+
max_new_tokens=max_new_tokens,
|
| 789 |
+
length_penalty=length_penalty,
|
| 790 |
+
weight_source=weight_source,
|
| 791 |
+
)
|
| 792 |
+
except Exception as e:
|
| 793 |
+
st.error(f"Generation error: {e}")
|
| 794 |
+
caption = None
|
| 795 |
|
| 796 |
if caption:
|
| 797 |
render_caption_card(
|
| 798 |
+
selected_model, caption, weight_source,
|
| 799 |
num_beams, length_penalty, max_new_tokens,
|
| 800 |
container=st,
|
| 801 |
)
|
|
|
|
| 857 |
compare_image = Image.open(compare_file).convert("RGB")
|
| 858 |
|
| 859 |
resolved_sources = {}
|
|
|
|
| 860 |
for model_key in MODEL_KEYS:
|
| 861 |
+
resolved_sources[model_key] = weight_source
|
| 862 |
+
if weight_source != "base":
|
| 863 |
+
missing = [
|
| 864 |
+
MODEL_SHORT[m]
|
| 865 |
+
for m in MODEL_KEYS
|
| 866 |
+
if not _finetuned_available_for_model(m, weight_source)
|
| 867 |
+
]
|
| 868 |
+
if missing:
|
| 869 |
+
st.warning(
|
| 870 |
+
"Missing fine-tuned weights for: "
|
| 871 |
+
+ ", ".join(missing)
|
| 872 |
+
+ ". Marking those results as unavailable."
|
| 873 |
+
)
|
| 874 |
|
| 875 |
# Generate captions from all 4 models
|
| 876 |
results = {}
|
|
|
|
| 896 |
else:
|
| 897 |
mode = "Baseline (Full Attention)"
|
| 898 |
|
| 899 |
+
if not _finetuned_available_for_model(model_key, weight_source):
|
| 900 |
+
results[model_key] = (
|
| 901 |
+
f"[Fine-tuned '{weight_source}' weights not available]"
|
| 902 |
+
if weight_source != "base"
|
| 903 |
+
else "[Not available]"
|
|
|
|
|
|
|
| 904 |
)
|
| 905 |
+
else:
|
| 906 |
+
try:
|
| 907 |
+
cap = generate_caption(
|
| 908 |
+
model_key, mode, compare_image,
|
| 909 |
+
num_beams=num_beams,
|
| 910 |
+
max_new_tokens=max_new_tokens,
|
| 911 |
+
length_penalty=length_penalty,
|
| 912 |
+
weight_source=weight_source,
|
| 913 |
+
)
|
| 914 |
+
results[model_key] = cap
|
| 915 |
+
except Exception as e:
|
| 916 |
+
results[model_key] = f"[Error: {e}]"
|
| 917 |
|
| 918 |
progress.progress(1.0, text="β
All models complete!")
|
| 919 |
|