# Imagecaption / app.py
# Author: HafizUsama — "Update app.py" (commit 507db78, verified)
import streamlit as st
from PIL import Image
import io
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
# Configure the page; Streamlit requires this to be the first st.* call.
st.set_page_config(page_title="BLIP-2 Image Captioner", layout="centered")
@st.cache_resource
def load_model(model_name: str, device: str):
    """Load and cache a BLIP-2 processor/model pair for *model_name*.

    Cached by Streamlit across reruns, keyed on (model_name, device).

    Returns:
        (processor, model): the processor, and the model moved to *device*
        in eval mode.
    """
    processor = Blip2Processor.from_pretrained(model_name)

    # float16 halves memory and speeds up inference, but is poorly supported
    # on CPU, so only request it for CUDA devices.
    preferred_dtype = torch.float16 if device.startswith("cuda") else torch.float32
    try:
        model = Blip2ForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=preferred_dtype
        )
    except Exception as err:
        # Fall back to the checkpoint's default dtype if the cast fails.
        st.warning(f"Model loaded with fallback dtype due to: {err}")
        model = Blip2ForConditionalGeneration.from_pretrained(model_name)

    model.to(device)
    model.eval()
    return processor, model
def generate_caption(processor, model, image: Image.Image, device: str, max_tokens: int = 64, num_beams: int = 3):
    """Generate a caption for *image* with a BLIP-2 model.

    Args:
        processor: BLIP-2 processor (image preprocessing + tokenizer).
        model: Blip2ForConditionalGeneration already placed on *device*.
        image: input PIL image (expected RGB).
        device: torch device string, e.g. "cpu" or "cuda:0".
        max_tokens: maximum number of new tokens to generate.
        num_beams: beam-search width (1 = greedy decoding).

    Returns:
        The decoded caption, stripped of surrounding whitespace.
    """
    # Cast pixel values to the model's dtype in addition to moving them to
    # the device: when the model was loaded in float16 the processor still
    # emits float32 pixel_values, which would otherwise raise a dtype
    # mismatch inside the vision encoder. BatchFeature.to() only casts
    # floating-point tensors, so integer tensors (if any) are unaffected.
    inputs = processor(images=image, return_tensors="pt").to(device=device, dtype=model.dtype)

    # Deterministic beam search; no gradients needed for inference.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            num_beams=num_beams,
            early_stopping=True,
            do_sample=False,
        )

    # processor.batch_decode is the documented API and delegates to the
    # tokenizer internally.
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return caption
def main():
    """Streamlit entry point: sidebar settings, cached model load, caption UI."""
    st.title("🖼️ BLIP-2 Image Captioning — Streamlit")
    st.markdown(
        "Upload an image and BLIP-2 will generate a short caption. This demo uses a Hugging Face BLIP-2 checkpoint."
    )

    # --- Sidebar controls ---
    st.sidebar.header("Settings")
    model_name = st.sidebar.selectbox(
        "Model checkpoint",
        (
            "Salesforce/blip2-flan-t5-large",
            "Salesforce/blip2-flan-t5-xl",
            "Salesforce/blip2-opt-2.7b"
        ),
        index=0,
        help="Choose the BLIP-2 variant. xl/2.7b need more GPU memory; use large for CPU or small GPUs."
    )
    use_gpu = st.sidebar.checkbox("Use GPU if available", value=True)
    max_tokens = st.sidebar.slider("Max new tokens", min_value=16, max_value=128, value=64, step=8)
    num_beams = st.sidebar.slider("Beams (higher=better, slower)", min_value=1, max_value=6, value=3)

    # --- Device selection ---
    device = "cpu"
    if use_gpu and torch.cuda.is_available():
        device = f"cuda:{torch.cuda.current_device()}"
    st.sidebar.write(f"Running on: **{device}**")

    # Load model (cached across reruns by @st.cache_resource).
    with st.spinner("Loading model — first load can take a while..."):
        processor, model = load_model(model_name, device)

    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=False)

    if uploaded_file is not None:
        try:
            image = Image.open(io.BytesIO(uploaded_file.read())).convert("RGB")
        except Exception as e:
            st.error(f"Couldn't open image: {e}")
            return
        # NOTE(review): use_column_width is deprecated in newer Streamlit in
        # favor of use_container_width; kept for compatibility with older versions.
        st.image(image, caption="Input image", use_column_width=True)
        if st.button("Generate caption"):
            with st.spinner("Generating caption..."):
                try:
                    caption = generate_caption(processor, model, image, device=device, max_tokens=max_tokens, num_beams=num_beams)
                    st.success("Caption generated")
                    st.markdown(f"### ✨ Caption\n{caption}")
                except Exception as e:
                    st.error(f"Error while generating caption: {e}")
    else:
        st.info("Upload an image to get started. You can also try one of the example images below.")
        col1, col2, col3 = st.columns(3)
        sample_urls = [
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/image_to_text.png",
            "https://images.unsplash.com/photo-1503023345310-bd7c1de61c7d",
            "https://images.unsplash.com/photo-1519681393784-d120267933ba",
        ]
        # Hoist the import out of the loop; enumerate replaces the O(n)
        # list.index() lookup (which would also misnumber duplicate URLs).
        from urllib.request import urlopen
        for idx, (col, url) in enumerate(zip((col1, col2, col3), sample_urls), start=1):
            if col.button(f"Use sample {idx}"):
                try:
                    # Timeout so a dead host cannot hang the script run; read
                    # the bytes eagerly so PIL never touches a closed socket.
                    with urlopen(url, timeout=10) as resp:
                        im = Image.open(io.BytesIO(resp.read())).convert("RGB")
                    st.image(im, use_column_width=True)
                    caption = generate_caption(processor, model, im, device=device, max_tokens=max_tokens, num_beams=num_beams)
                    st.markdown(f"### ✨ Caption\n{caption}")
                except Exception as e:
                    st.error(f"Failed to load sample image: {e}")
# Run the app only when executed directly (not when imported as a module).
if __name__ == "__main__":
    main()