# Afsha001's picture
# update name
# eb27bae verified
import base64
import gc
import html
import os
from collections import Counter
from io import BytesIO

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
import streamlit as st
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
# Page-level Streamlit configuration for the app.
st.set_page_config(
    page_title="Image Captioning Refinement Fusion System",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Jina reranker endpoint and credentials; the key is read from the environment.
JINA_KEY = os.environ.get("JINA_KEY", "")
JINA_URL = "https://api.jina.ai/v1/rerank"
JINA_HEADERS = {
    "Authorization": f"Bearer {JINA_KEY}",
    "Content-Type": "application/json",
}

# Without a key the Jina scoring step cannot authenticate, so show an error
# and stop the script before any of the UI below is built.
if not JINA_KEY:
    # The arrow in this message was mis-encoded ("β†’"); restored to "→".
    st.error("JINA_KEY missing. Go to Space Settings → Secrets and add it.")
    st.stop()
@st.cache_resource
def load_local_models():
    """Load the Florence-2, BLIP and Qwen checkpoints used by the pipeline.

    The function is wrapped in ``st.cache_resource`` so the returned objects
    can be reused instead of being reloaded on every call.

    Returns:
        A 6-tuple ``(florence_processor, florence_model, blip_processor,
        blip_itm_model, qwen_tokenizer, qwen_model)``. All three models are
        loaded in float32 and switched to eval mode.
    """
    # Imported here rather than at module level, as in the original layout.
    from transformers import (
        AutoModelForCausalLM,
        AutoProcessor,
        AutoTokenizer,
        BlipForImageTextRetrieval,
        BlipProcessor,
    )

    florence_id = "microsoft/Florence-2-large"
    blip_processor_id = "Salesforce/blip-image-captioning-large"
    blip_itm_id = "Salesforce/blip-itm-large-coco"
    qwen_id = "Qwen/Qwen2.5-1.5B-Instruct"

    # Free whatever can be freed before pulling several large checkpoints in.
    gc.collect()

    # Florence-2 ships custom modelling code, hence trust_remote_code.
    florence_processor = AutoProcessor.from_pretrained(
        florence_id, trust_remote_code=True
    )
    florence_model = AutoModelForCausalLM.from_pretrained(
        florence_id, trust_remote_code=True, torch_dtype=torch.float32
    )
    florence_model.eval()

    # The processor comes from the captioning checkpoint, the model from the
    # image-text-matching checkpoint.
    blip_processor = BlipProcessor.from_pretrained(blip_processor_id)
    blip_itm_model = BlipForImageTextRetrieval.from_pretrained(
        blip_itm_id, torch_dtype=torch.float32
    )
    blip_itm_model.eval()

    qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_id)
    qwen_model = AutoModelForCausalLM.from_pretrained(
        qwen_id, torch_dtype=torch.float32
    )
    qwen_model.eval()

    return (
        florence_processor,
        florence_model,
        blip_processor,
        blip_itm_model,
        qwen_tokenizer,
        qwen_model,
    )
def image_to_bytes(image: Image.Image) -> bytes:
    """Encode *image* as a JPEG (quality 85) and return the raw bytes.

    JPEG cannot store every Pillow mode (for example images with an alpha
    channel or a palette). If the first save attempt fails with ``OSError``
    the image is converted to RGB and saved again, so such inputs no longer
    crash. Images that already save cleanly produce the same bytes as before.

    Args:
        image: The Pillow image to encode.

    Returns:
        The JPEG-encoded image data.

    Raises:
        OSError: If the image still cannot be written after RGB conversion.
    """
    buf = BytesIO()
    try:
        image.save(buf, format="JPEG", quality=85)
    except OSError:
        # Start from a clean buffer: the failed attempt may have written
        # partial data before raising.
        buf = BytesIO()
        image.convert("RGB").save(buf, format="JPEG", quality=85)
    return buf.getvalue()
def image_to_data_uri(image: Image.Image) -> str:
    """Return *image* as a ``data:image/jpeg;base64,...`` URI string.

    The image is JPEG-encoded via ``image_to_bytes`` and the bytes are
    base64-encoded into the URI payload.
    """
    encoded = base64.b64encode(image_to_bytes(image)).decode()
    return "data:image/jpeg;base64," + encoded
# ============================================================================
# STEP 1 — FLORENCE-2-LARGE: 5 DISTINCT CAPTION APPROACHES
#
# Cap 1: <CAPTION> greedy
# → single concise sentence, primary subject only
# Cap 2: <CAPTION> sampling temp=1.0
# → alt-text accessibility style, concise but different phrasing
# Cap 3: <DETAILED_CAPTION> temp=0.7
# → paragraph describing foreground, background, colors
# Cap 4: <DETAILED_CAPTION> temp=1.1
# → focuses on mood, atmosphere, implied action
# Cap 5: <MORE_DETAILED_CAPTION> temp=0.8
# → exhaustive breakdown of every visible element
# ============================================================================
def generate_captions_florence(image: Image.Image, florence_proc, florence_mod) -> list:
    """Produce five candidate captions for *image* with Florence-2.

    Five (task prompt, token budget, generation kwargs) recipes are run in
    order. Each result is stripped and lower-cased. A recipe that raises or
    yields an empty string contributes a generic placeholder caption instead,
    and a Streamlit warning is shown for the exception case.

    Duplicates are then removed (first occurrence wins). If fewer than two
    distinct captions remain, the undeduplicated list is used instead. The
    result is padded with copies of its first entry up to five items.

    Args:
        image: Input image; its width and height are passed to the
            processor's post-processing step.
        florence_proc: Florence-2 processor.
        florence_mod: Florence-2 generation model.

    Returns:
        A list of exactly five caption strings.
    """
    placeholder_caption = "a scene shown in the image"
    size = (image.width, image.height)
    recipes = (
        ("<CAPTION>", 30, {"num_beams": 1}),
        ("<CAPTION>", 35, {"do_sample": True, "temperature": 1.0, "top_p": 0.92}),
        ("<DETAILED_CAPTION>", 80, {"do_sample": True, "temperature": 0.7, "top_p": 0.90}),
        ("<DETAILED_CAPTION>", 90, {"do_sample": True, "temperature": 1.1, "top_p": 0.95}),
        ("<MORE_DETAILED_CAPTION>", 120, {"do_sample": True, "temperature": 0.8, "top_p": 0.92}),
    )

    produced = []
    for prompt, token_budget, decode_kwargs in recipes:
        try:
            encoded = florence_proc(text=prompt, images=image, return_tensors="pt")
            with torch.no_grad():
                token_ids = florence_mod.generate(
                    input_ids=encoded["input_ids"],
                    pixel_values=encoded["pixel_values"],
                    max_new_tokens=token_budget,
                    **decode_kwargs,
                )
            # Special tokens are kept because post-processing parses them.
            decoded = florence_proc.batch_decode(token_ids, skip_special_tokens=False)[0]
            result = florence_proc.post_process_generation(
                decoded, task=prompt, image_size=size
            )
            text = result.get(prompt, "").strip().lower()
            produced.append(text or placeholder_caption)
        except Exception as e:
            st.warning(f"Florence {prompt} error: {str(e)[:80]}")
            produced.append(placeholder_caption)

    # Order-preserving de-duplication.
    distinct = list(dict.fromkeys(produced))
    if len(distinct) < 2:
        distinct = list(produced)
    # Downstream steps expect five entries, so pad with the first caption.
    distinct += [distinct[0]] * (5 - len(distinct))
    return distinct[:5]
def compute_itm_scores(image, captions, blip_proc, blip_itm) -> list:
    """Score each caption against *image* with the BLIP ITM head.

    For every caption the processor output is fed to the model, a softmax is
    taken over ``itm_score`` along dim 1, and the value in column 1 is kept,
    rounded to four decimals. A caption whose scoring raises gets ``0.0`` and
    a Streamlit warning.

    Args:
        image: Image passed to the processor.
        captions: Iterable of caption strings.
        blip_proc: BLIP processor (called with images/text keywords).
        blip_itm: BLIP image-text-matching model.

    Returns:
        A list of floats, one per caption, in input order.
    """

    def _match_probability(caption):
        # One forward pass per caption; gradients are never needed here.
        encoded = blip_proc(
            images=image, text=caption, return_tensors="pt", padding=True
        )
        with torch.no_grad():
            result = blip_itm(**encoded)
        probabilities = torch.softmax(result.itm_score, dim=1)
        return round(float(probabilities[:, 1].item()), 4)

    results = []
    for caption in captions:
        try:
            results.append(_match_probability(caption))
        except Exception as e:
            st.warning(f"ITM error: {str(e)[:60]}")
            results.append(0.0)
    return results
def compute_jina_scores(image: Image.Image, captions: list) -> list:
    """Score each caption against *image* via the Jina rerank HTTP API.

    The image is encoded once as a data URI. For each caption a POST is sent
    to ``JINA_URL`` with the caption as the query and the data URI as the
    only document. The first result's ``relevance_score`` (rounded to four
    decimals) becomes the score. Non-200 responses, missing or empty results
    and any exception all yield ``0.0``; non-200 responses and exceptions
    also show a Streamlit warning.

    Args:
        image: Image to score against.
        captions: List of caption strings.

    Returns:
        A list of floats, one per caption, in input order.
    """
    # NOTE(review): the image is sent as a bare data-URI string inside
    # "documents" — confirm against the Jina rerank API whether image
    # documents must instead be wrapped as objects.
    data_uri = image_to_data_uri(image)

    results = []
    for caption in captions:
        value = 0.0  # stays 0.0 unless a usable score comes back
        try:
            reply = requests.post(
                JINA_URL,
                headers=JINA_HEADERS,
                json={
                    "model": "jina-reranker-m0",
                    "query": caption,
                    "documents": [data_uri],
                    "top_n": 1,
                },
                timeout=30,
            )
            if reply.status_code != 200:
                st.warning(f"Jina API error {reply.status_code}: {reply.text[:100]}")
            else:
                body = reply.json()
                hits = body["results"] if "results" in body else None
                if hits:
                    value = round(float(hits[0].get("relevance_score", 0.0)), 4)
        except Exception as e:
            st.warning(f"Jina exception: {str(e)[:60]}")
        results.append(value)
    return results
def compute_cosine_scores(image, captions, blip_proc, blip_itm) -> list:
    """Cosine similarity between the image embedding and each caption embedding.

    The image goes through ``blip_itm.vision_model`` and ``vision_proj``; the
    captions go through ``blip_itm.text_encoder`` and ``text_proj``. In both
    cases the hidden state at sequence position 0 is projected, L2-normalised,
    and the two sets are compared with ``cosine_similarity``.

    Args:
        image: Image passed to the processor.
        captions: List of caption strings (tokenised as one padded batch,
            truncated to 512 tokens).
        blip_proc: BLIP processor.
        blip_itm: BLIP ITM model exposing the sub-modules named above.

    Returns:
        One similarity per caption, rounded to four decimals. On any
        exception a Streamlit warning is shown and a list of ``0.0`` of the
        same length as *captions* is returned.
    """

    def _unit_embedding(encoder_output, projection):
        # Project the position-0 hidden state and L2-normalise the rows.
        with torch.no_grad():
            projected = projection(encoder_output.last_hidden_state[:, 0, :])
        return normalize(projected.numpy(), norm="l2")

    try:
        pixel_batch = blip_proc(images=image, return_tensors="pt")
        with torch.no_grad():
            vision_out = blip_itm.vision_model(
                pixel_values=pixel_batch["pixel_values"]
            )
        image_vec = _unit_embedding(vision_out, blip_itm.vision_proj)

        token_batch = blip_proc(
            text=captions,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            text_out = blip_itm.text_encoder(
                input_ids=token_batch["input_ids"],
                attention_mask=token_batch["attention_mask"],
            )
        caption_vecs = _unit_embedding(text_out, blip_itm.text_proj)

        # Single image row against every caption row.
        similarities = cosine_similarity(image_vec, caption_vecs)[0]
        return [round(float(value), 4) for value in similarities]
    except Exception as e:
        st.warning(f"Cosine error: {str(e)[:60]}")
        return [0.0] * len(captions)
def majority_voting(captions, itm, jina, cosine) -> tuple:
    """Pick the two captions most often ranked in a scorer's top two.

    Each of the three score lists votes for its two highest-scoring caption
    indices. The two indices with the most votes win; ties keep the order in
    which indices first received a vote (ITM, then Jina, then cosine).

    A score list whose values are all equal abstains. The scoring helpers
    return all-zero lists when they fail, and ranking such a list would cast
    two arbitrary votes driven only by sort order.

    Args:
        captions: Candidate captions.
        itm: ITM scores, aligned with *captions*.
        jina: Jina scores, aligned with *captions*.
        cosine: Cosine scores, aligned with *captions*.

    Returns:
        ``(best_caption, second_caption, [best_idx, second_idx], vote_counts)``
        where ``vote_counts`` maps caption index to number of votes. If no
        scorer votes, the first two captions are returned; with a single
        caption it is returned twice.

    Raises:
        ValueError: If *captions* is empty.
    """
    total = len(captions)
    if total == 0:
        raise ValueError("majority_voting requires at least one caption")

    votes = []
    for scores in (itm, jina, cosine):
        values = np.asarray(scores, dtype=float)
        # A scorer that cannot tell the captions apart has nothing to say.
        if values.size < 2 or np.all(values == values[0]):
            continue
        # Stable sort makes tie-breaking deterministic; reversing yields
        # descending order.
        ranked = np.argsort(values, kind="stable")[::-1]
        votes.extend(int(index) for index in ranked[:2])

    counts = Counter(votes)
    top2 = [index for index, _ in counts.most_common(2)]
    if len(top2) < 2:
        # Votes arrive in pairs of distinct indices, so this only happens
        # when every scorer abstained.
        top2 = [0, min(1, total - 1)]
    return captions[top2[0]], captions[top2[1]], top2, dict(counts)
def _trim_to_last_sentence(text: str) -> str:
    """Drop a trailing unfinished sentence from *text*, if one can be isolated."""
    if not text or text[-1] in ".!?":
        return text
    # Walk backwards to the last terminator that is followed by whitespace.
    for position in range(len(text) - 2, -1, -1):
        if text[position] in ".!?" and text[position + 1].isspace():
            return text[: position + 1]
    return text


def fuse_captions(cap1: str, cap2: str, qwen_tok, qwen_mod) -> str:
    """Merge two captions of the same image into one detailed caption.

    A chat prompt is built from the two captions and the model generates
    greedily. Only the newly generated tokens are decoded, and a leading
    label such as ``Caption:`` is stripped.

    The prompt asks for 5 to 6 sentences while generation is capped at 120
    new tokens, so the output can be cut mid-sentence. When the cap is
    reached, the dangling fragment after the last complete sentence is
    removed.

    Args:
        cap1: First caption; also the fallback result.
        cap2: Second caption.
        qwen_tok: Tokenizer providing ``apply_chat_template`` and ``decode``.
        qwen_mod: Model providing ``generate``.

    Returns:
        The fused caption, or *cap1* if generation produced nothing or
        raised (a Streamlit warning is shown in the latter case).
    """
    max_new_tokens = 120
    system_prompt = (
        "You write image captions. "
        "You will receive two captions of the same image. "
        "Your job is to combine them into one detailed caption. "
        "Include ALL specific details you find: "
        "the clothing colors and style of each person, "
        "what each person looks like and what they are doing, "
        "the objects and surroundings visible around them, "
        "and the setting or background of the scene. "
        "Write 5 to 6 sentences. Use simple, clear, everyday words. "
        # The dash here was mis-encoded in the prompt sent to the model.
        "Do NOT summarize or shorten — keep every specific detail. "
        "Only include what is clearly visible. "
        "Return ONLY the caption, nothing else."
    )
    user_prompt = (
        f"Caption A: {cap1}\n"
        f"Caption B: {cap2}\n\n"
        "Write a detailed caption that includes all the clothing, "
        "people, objects and background in details:"
    )
    try:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        text = qwen_tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = qwen_tok([text], return_tensors="pt")
        with torch.no_grad():
            generated_ids = qwen_mod.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        # The returned sequence starts with the prompt; keep only new tokens.
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        fused = qwen_tok.decode(output_ids, skip_special_tokens=True).strip()
        for prefix in ("Caption:", "Result:", "Answer:", "Fused caption:"):
            if fused.lower().startswith(prefix.lower()):
                fused = fused[len(prefix):].strip()
        # Hitting the token cap means the text was probably cut off.
        if len(output_ids) >= max_new_tokens:
            fused = _trim_to_last_sentence(fused)
        return fused if fused else cap1
    except Exception as e:
        st.warning(f"Qwen fusion error: {str(e)[:80]}")
        return cap1
def compute_caption_quality(image, final_caption, blip_proc, blip_itm) -> tuple:
    """Score the final caption against the image with two BLIP-based metrics.

    * ITM score: softmax over ``itm_score`` (dim 1), column 1.
    * Cosine score: cosine similarity between the L2-normalised projected
      image embedding and caption embedding (hidden state at position 0).

    Each metric is computed independently. A failure in one sets that metric
    to ``0.0`` and shows a Streamlit warning instead of being swallowed
    silently. Only ``Exception`` is caught, so interpreter- and
    framework-level ``BaseException`` signals still propagate.

    Args:
        image: Image passed to the processor.
        final_caption: Caption string to evaluate.
        blip_proc: BLIP processor.
        blip_itm: BLIP ITM model.

    Returns:
        ``(average, itm_score, cosine_score)``, each rounded to four
        decimals; the average is taken before rounding the components.
    """
    try:
        inputs = blip_proc(
            images=image, text=final_caption,
            return_tensors="pt", padding=True
        )
        with torch.no_grad():
            out = blip_itm(**inputs)
        itm_score = torch.nn.functional.softmax(
            out.itm_score, dim=1
        )[:, 1].item()
    except Exception as e:
        st.warning(f"Quality ITM error: {str(e)[:60]}")
        itm_score = 0.0

    try:
        img_inp = blip_proc(images=image, return_tensors="pt")
        with torch.no_grad():
            vis = blip_itm.vision_model(pixel_values=img_inp["pixel_values"])
            img_feat = blip_itm.vision_proj(vis.last_hidden_state[:, 0, :]).numpy()
        img_feat = normalize(img_feat, norm="l2")
        cap_inp = blip_proc(
            text=[final_caption], return_tensors="pt",
            padding=True, truncation=True, max_length=512
        )
        with torch.no_grad():
            txt = blip_itm.text_encoder(
                input_ids=cap_inp["input_ids"],
                attention_mask=cap_inp["attention_mask"]
            )
            cap_feat = blip_itm.text_proj(txt.last_hidden_state[:, 0, :]).numpy()
        cap_feat = normalize(cap_feat, norm="l2")
        cosine_score = float(cosine_similarity(img_feat, cap_feat)[0][0])
    except Exception as e:
        st.warning(f"Quality cosine error: {str(e)[:60]}")
        cosine_score = 0.0

    avg_score = round((itm_score + cosine_score) / 2, 4)
    return avg_score, round(itm_score, 4), round(cosine_score, 4)
# ============================================================================
# GAUGE — updated to match reference style
# Bright saturated zone colors, sharp black needle, clean arc, no dark shades
# ============================================================================
def render_gauge(score, itm, cosine, placeholder):
    """Draw the caption-quality gauge and score breakdown into *placeholder*.

    The score is bucketed into Good (>= 0.75), Moderate (>= 0.50),
    Low (>= 0.25) or Poor, which picks the label and accent colour. A Plotly
    gauge over the range 0..1 is rendered next to a textual breakdown.

    Args:
        score: Overall score shown on the gauge.
        itm: Image-text-match score shown in the breakdown.
        cosine: Embedding-similarity score shown in the breakdown.
        placeholder: Streamlit element to render into; the caller passes an
            ``st.empty()`` slot.
    """
    if score >= 0.75:
        label, bar_color = "Good", "#16a34a"
    elif score >= 0.50:
        label, bar_color = "Moderate", "#d97706"
    elif score >= 0.25:
        label, bar_color = "Low", "#ca8a04"
    else:
        label, bar_color = "Poor", "#dc2626"

    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=score,
        number={
            "font": {"size": 36, "color": bar_color, "family": "Arial Black"},
            "suffix": ""
        },
        gauge={
            "axis": {
                "range": [0, 1],
                "tickwidth": 2,
                "tickcolor": "#111827",
                "tickfont": {"size": 11, "color": "#374151"}
            },
            # Thin dark bar plus a threshold line at the same value.
            "bar": {
                "color": "#111827",
                "thickness": 0.06
            },
            "bgcolor": "white",
            "borderwidth": 0,
            "steps": [
                {"range": [0.00, 0.25], "color": "#ef4444"},
                {"range": [0.25, 0.50], "color": "#f59e0b"},
                {"range": [0.50, 0.75], "color": "#84cc16"},
                {"range": [0.75, 1.00], "color": "#22c55e"},
            ],
            "threshold": {
                "line": {"color": "#111827", "width": 5},
                "thickness": 0.85,
                "value": score
            }
        },
        title={
            "text": f"Caption Quality Score<br><b style='color:{bar_color};font-size:15px'>{label}</b>",
            "font": {"size": 13, "color": "#374151"}
        }
    ))
    fig.update_layout(
        height=240,
        margin=dict(l=15, r=15, t=55, b=5),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        font={"color": "#374151", "family": "Arial"}
    )

    # An st.empty() slot holds a single element, so writing several elements
    # directly into it leaves only the last one. Grouping them in a container
    # keeps the spacer and the columns together.
    with placeholder.container():
        st.markdown("<br>", unsafe_allow_html=True)
        g_col, s_col = st.columns([3, 2])
        with g_col:
            st.plotly_chart(fig, use_container_width=True)
        with s_col:
            st.markdown("<br><br>", unsafe_allow_html=True)
            st.markdown("**Score Breakdown**")
            st.markdown(f"Image-Text Match: **{itm}**")
            st.markdown(f"Embedding Similarity: **{cosine}**")
            st.markdown(f"Overall Score: **{score} / 1.00**")
            st.markdown(
                f"<span style='background:{bar_color};color:white;"
                f"padding:4px 12px;border-radius:12px;"
                f"font-weight:700;font-size:13px;'>{label}</span>",
                unsafe_allow_html=True
            )
# ============================================================================
# SIDEBAR — pipeline steps + live accuracy section (session_state)
# ============================================================================
# Static markdown shown in the sidebar. The text is kept flush-left because
# leading whitespace is significant to the markdown renderer.
_PIPELINE_STEPS_MD = """
**1. Florence-2-Large** (Local)
Generate 5 captions
**2. BLIP ITM** (Local)
Image-text matching
**3. Jina Reranker M0** (API)
Semantic reranking
**4. Cosine Similarity** (Local)
Embedding similarity
**5. Majority Voting**
Best 2 captions selected
**6. Qwen2.5-1.5B** (Local)
Caption fusion
"""

_QUALITY_METRICS_MD = """
**BLIP ITM** (Image-Text Match)
Measures how well the caption
matches the image content.
**Cosine Similarity**
Measures embedding distance
between image and caption.
"""

# Sidebar: a title followed by markdown sections rendered top to bottom —
# pipeline overview, where each model runs, then the metric descriptions.
with st.sidebar:
    st.title("Image Captioning Refinement Fusion")
    for section in (
        "---",
        "### Pipeline Steps",
        _PIPELINE_STEPS_MD,
        "---",
        "**Local:** Florence-2, BLIP ITM, Qwen2.5",
        "**API:** Jina",
        # ── accuracy section ──────────────
        "---",
        "### Caption Quality Metrics",
        _QUALITY_METRICS_MD,
    ):
        st.markdown(section)
# ============================================================================
# MAIN UI
# ============================================================================
# Main page: upload an image, run the six-step pipeline on demand, then show
# the fused caption and its quality gauge.
st.title("Image Captioning Refinement Fusion System")
st.markdown("Upload an image to generate a refined, grounded caption.")
st.markdown("---")

uploaded_file = st.file_uploader(
    "Select an image",
    type=["jpg", "jpeg", "png"]
)

if uploaded_file is not None:
    # Normalise to RGB so every downstream step sees three channels.
    input_image = Image.open(uploaded_file).convert("RGB")
    col_img, col_run = st.columns([1, 1])

    with col_img:
        st.image(input_image, caption="Uploaded Image", use_container_width=True)
        # NOTE(review): placing the gauge slot inside the image column is an
        # assumption about the intended layout — confirm.
        gauge_placeholder = st.empty()

    with col_run:
        if st.button("Generate Caption", type="primary", use_container_width=True):
            with st.spinner("Loading local models (first run takes 3-4 min)..."):
                (
                    florence_proc, florence_mod,
                    blip_proc, blip_itm,
                    qwen_tok, qwen_mod
                ) = load_local_models()

            progress = st.progress(0)
            status = st.empty()

            # Step 1: candidate captions.
            status.info("Step 1/6: Generating captions with Florence-2-Large...")
            captions = generate_captions_florence(input_image, florence_proc, florence_mod)
            progress.progress(16)
            with st.expander("5 Generated Captions", expanded=True):
                for i, cap in enumerate(captions):
                    st.write(f"**{i+1}.** {cap}")

            # Steps 2-4: three independent scorers over the same captions.
            status.info("Step 2/6: Computing BLIP ITM scores...")
            itm_scores = compute_itm_scores(input_image, captions, blip_proc, blip_itm)
            progress.progress(32)

            status.info("Step 3/6: Computing Jina Reranker scores...")
            jina_scores = compute_jina_scores(input_image, captions)
            progress.progress(50)

            status.info("Step 4/6: Computing Cosine Similarity scores...")
            cosine_scores = compute_cosine_scores(input_image, captions, blip_proc, blip_itm)
            progress.progress(66)

            scores_df = pd.DataFrame({
                "Caption": [f"Cap {i+1}: {c[:50]}" for i, c in enumerate(captions)],
                "ITM": itm_scores,
                "Jina": jina_scores,
                "Cosine": cosine_scores
            })
            with st.expander("All Scores", expanded=False):
                st.dataframe(scores_df, use_container_width=True, hide_index=True)

            # Step 5: choose the two captions to fuse.
            status.info("Step 5/6: Running majority voting...")
            best_1, best_2, _, _ = majority_voting(
                captions, itm_scores, jina_scores, cosine_scores
            )
            progress.progress(83)
            st.markdown("### Majority Voted Captions")
            c1, c2 = st.columns(2)
            with c1:
                st.success(f"1. {best_1}")
            with c2:
                st.info(f"2. {best_2}")

            # Step 6: fuse the two winners into the final caption.
            status.info("Step 6/6: Fusing captions with Qwen2.5-1.5B...")
            final = fuse_captions(best_1, best_2, qwen_tok, qwen_mod)
            progress.progress(100)
            status.success("Pipeline complete!")

            st.markdown("---")
            st.markdown("### Final Fused Caption")
            # The caption is model output rendered as raw HTML, so escape it:
            # characters such as "<" or "&" must not be read as markup.
            st.markdown(
                f"<div style='"
                f"background:linear-gradient(135deg,#667eea,#764ba2);"
                f"padding:24px;border-radius:12px;color:white;"
                f"font-size:18px;font-weight:500;text-align:center;"
                f"line-height:1.6;'>{html.escape(final)}</div>",
                unsafe_allow_html=True
            )

            # Quality is measured on the unescaped caption text.
            avg_score, itm_q, cosine_q = compute_caption_quality(
                input_image, final, blip_proc, blip_itm
            )
            # Keep the latest scores in session_state.
            # NOTE(review): nothing in the visible code reads these back.
            st.session_state.avg_score = avg_score
            st.session_state.itm_q = itm_q
            st.session_state.cosine_q = cosine_q
            render_gauge(avg_score, itm_q, cosine_q, gauge_placeholder)