|
|
|
|
|
""" |
|
|
Streamlit BLIP-2 Image Captioning demo |
|
|
- Uses HuggingFace transformers' Blip2Processor + Blip2ForConditionalGeneration |
|
|
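- Caches the model & processor with st.cache_resource. That cache is shared across all user sessions (it is not per-session), so Streamlit reruns reuse the loaded model instead of reloading it.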
|
|
- Designed for deployment on Hugging Face Spaces (use Docker SDK / Streamlit template). |
|
|
""" |
|
|
|
|
|
import streamlit as st |
|
|
from PIL import Image |
|
|
import io |
|
|
import torch |
|
|
from transformers import Blip2Processor, Blip2ForConditionalGeneration |
|
|
|
|
|
# Page-wide settings: browser tab title, wide layout, sidebar open on load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="BLIP-2 Image Captioning",
)

# Sidebar: short usage notes for the demo.
with st.sidebar:
    st.title("BLIP-2 Caption Demo")
    st.markdown(
        """
Upload an image and BLIP-2 will generate a caption.
- Model choices: choose a BLIP-2 model (large models may need GPU / won’t fit on CPU).
- For Spaces deployment, prefer smaller/flan-xl variants or use inference API.
"""
    )
|
|
|
|
|
|
|
|
# Hugging Face Hub id of the BLIP-2 checkpoint the model selector should start on.
DEFAULT_MODEL: str = "Salesforce/blip2-opt-2.7b"
|
|
|
|
|
@st.cache_resource(show_spinner=False, max_entries=1)
def load_model_and_processor(model_name: str):
    """Load and cache the BLIP-2 processor and model.

    ``st.cache_resource`` is shared by every session of the app. ``max_entries=1``
    keeps only the most recently requested checkpoint resident: BLIP-2 models are
    several GB each, so caching every model a user toggles through in the model
    selector would otherwise accumulate until the host runs out of memory.

    On CUDA the weights are loaded in float16 (about half the memory of
    float32); on CPU they stay float32, as before.

    Args:
        model_name: Hugging Face Hub id of a BLIP-2 checkpoint.

    Returns:
        ``(processor, model, device)`` where ``device`` is the ``torch.device``
        the model was moved to.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    dtype = torch.float16 if use_cuda else torch.float32

    processor = Blip2Processor.from_pretrained(model_name)
    model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype=dtype)
    model.to(device)
    return processor, model, device


def generate_caption(processor, model, device, pil_image: Image.Image, max_new_tokens=50, num_beams=4):
    """Generate caption text for a PIL image using BLIP-2.

    Args:
        processor: the ``Blip2Processor`` returned by ``load_model_and_processor``.
        model: the matching ``Blip2ForConditionalGeneration``.
        device: device the model lives on; inputs are moved there.
        pil_image: image to caption; converted to RGB if it is in another mode.
        max_new_tokens: upper bound on the number of generated tokens.
        num_beams: beam-search width passed to ``model.generate``.

    Returns:
        The decoded caption with special tokens and surrounding whitespace removed.
    """
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # Move inputs to the model's device and cast floating-point tensors
    # (pixel_values) to the model's dtype, so an fp16 model on CUDA does not
    # receive fp32 inputs. BatchFeature.to only casts floating-point tensors,
    # so any integer tensors (e.g. input_ids) keep their dtype.
    inputs = processor(images=pil_image, return_tensors="pt").to(device, model.dtype)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, num_beams=num_beams)

    # The decoded text can carry leading/trailing whitespace or newlines.
    caption = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return caption.strip()
|
|
|
|
|
|
|
|
# Two-column layout: controls on the left, preview and caption on the right.
col1, col2 = st.columns([1, 1.2])

with col1:
    st.header("Upload image")
    uploaded = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"], accept_multiple_files=False)

    st.markdown("**Model selection**")
    model_options = [
        "Salesforce/blip2-flan-t5-xl",
        "Salesforce/blip2-opt-2.7b",
        "Salesforce/blip2-flan-t5-xxl",
    ]
    # Pre-select DEFAULT_MODEL by looking up its actual position rather than
    # inferring it from the id suffix; ids not in the list fall back to option 0.
    default_index = model_options.index(DEFAULT_MODEL) if DEFAULT_MODEL in model_options else 0
    model_name = st.selectbox(
        "Pick BLIP-2 model (large models may not run on CPU)",
        options=model_options,
        index=default_index,
        help="Large models require GPU or HF Inference API; choose smaller if you have no GPU.",
    )

    # Generation settings used when a caption is generated.
    max_tokens = st.slider("Max caption length (tokens)", min_value=10, max_value=200, value=50)
    num_beams = st.slider("Beam search width (num_beams)", min_value=1, max_value=8, value=4)

    st.write("---")
    st.markdown("Tips:")
    st.markdown(
        "- If deploying on CPU-only Spaces, use a smaller/flan model or use the Hugging Face Inference API.\n"
        "- Model loading is cached to speed up subsequent requests."
    )
|
|
|
|
|
def _render_caption_panel(uploaded, model_name, max_new_tokens, num_beams):
    """Render the preview/caption column for the current upload and settings.

    Each failure (no upload, undecodable image, model load error, generation
    error) shows a message and returns early instead of raising, so the rest of
    the page still renders.
    """
    if uploaded is None:
        st.info("Upload an image on the left to generate a caption.")
        return

    # getvalue() returns the whole upload regardless of the buffer's read position.
    image_bytes = uploaded.getvalue()
    try:
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as e:
        # PIL raises UnidentifiedImageError (an OSError subclass) for non-image
        # data and OSError for truncated/corrupt files.
        st.error(f"Could not read the uploaded file as an image: {e}")
        return

    # NOTE(review): use_column_width is deprecated in recent Streamlit releases in
    # favour of use_container_width; kept for compatibility with older versions.
    st.image(pil_image, use_column_width=True)

    with st.spinner("Loading model (cached after first load)..."):
        try:
            processor, model, device = load_model_and_processor(model_name)
        except Exception as e:  # UI boundary: download failures, OOM, bad checkpoint id, ...
            st.error(f"Error loading model {model_name}: {e}")
            st.info("If model is too large or out-of-memory, try a smaller model or use GPU.")
            return

    # Any widget interaction (including clicking the download button) reruns the
    # script and st.button then returns False, so the last caption is kept in
    # session state, tagged with the inputs that produced it.
    request = (uploaded.name, hash(image_bytes), model_name, max_new_tokens, num_beams)

    if st.button("Generate caption"):
        with st.spinner("Generating caption..."):
            try:
                caption = generate_caption(
                    processor, model, device, pil_image, max_new_tokens=max_new_tokens, num_beams=num_beams
                )
            except Exception as e:  # UI boundary: report instead of crashing the page
                st.error(f"Error during generation: {e}")
                st.info("If model is too large or out-of-memory, try a smaller model or use GPU.")
                return
        st.session_state["blip2_last_caption"] = (request, caption)
        st.success("Caption generated")

    last = st.session_state.get("blip2_last_caption")
    if last is not None and last[0] == request:
        caption = last[1]
        st.markdown(f"**Caption:** {caption}")
        st.download_button("Download caption (.txt)", caption, file_name="caption.txt")


with col2:
    st.header("Preview & Caption")
    _render_caption_panel(uploaded, model_name, max_tokens, num_beams)
|
|
|
|
|
|
|
|
# Page footer: divider, deployment advice, and pointers to the relevant docs.
footer_text = (
    "Built with BLIP-2 + Transformers. "
    "For production or public Spaces hosting, consider using Hugging Face Inference API "
    "or a smaller model variant to avoid OOM on CPU-only hosts."
)

st.markdown("---")
st.markdown(footer_text)
st.caption("Docs: BLIP-2 (Transformers), Hugging Face Spaces (Streamlit), Streamlit caching & uploader.")
|
|
|