import io
from urllib.request import urlopen

import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

st.set_page_config(page_title="BLIP-2 Image Captioner", layout="centered")


@st.cache_resource
def load_model(model_name: str, device: str):
    """
    Load BLIP-2 processor and model and return them.

    Cached with st.cache_resource so repeated reruns of the script reuse the
    same (processor, model) pair instead of re-downloading/re-loading.

    Args:
        model_name: Hugging Face checkpoint id (e.g. "Salesforce/blip2-flan-t5-large").
        device: torch device string, "cpu" or "cuda:N".

    Returns:
        Tuple of (processor, model) with the model moved to `device` and in eval mode.
    """
    processor = Blip2Processor.from_pretrained(model_name)
    try:
        # float16 halves memory and speeds up inference, but is poorly
        # supported on CPU — only use it on CUDA devices.
        dtype = torch.float16 if device.startswith("cuda") else torch.float32
        model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype=dtype)
    except Exception as e:
        # fallback to default dtype
        st.warning(f"Model loaded with fallback dtype due to: {e}")
        model = Blip2ForConditionalGeneration.from_pretrained(model_name)
    model.to(device)
    model.eval()
    return processor, model


def generate_caption(processor, model, image: Image.Image, device: str,
                     max_tokens: int = 64, num_beams: int = 3) -> str:
    """
    Generate a caption for `image` with a loaded BLIP-2 model.

    Args:
        processor: Blip2Processor matching the model checkpoint.
        model: Blip2ForConditionalGeneration already on `device`.
        image: PIL image (RGB).
        device: torch device string the model lives on.
        max_tokens: upper bound on newly generated tokens.
        num_beams: beam-search width (1 = greedy).

    Returns:
        The decoded caption, stripped of surrounding whitespace.
    """
    # Cast pixel values to the model's dtype as well as its device: when the
    # model was loaded in float16 on GPU, float32 inputs would raise a
    # dtype-mismatch error during the forward pass.
    inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)
    # Inference only — no gradients needed.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            num_beams=num_beams,
            early_stopping=True,
            do_sample=False,
        )
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return caption


def main():
    """Render the Streamlit UI: sidebar settings, image upload, caption output."""
    st.title("🖼️ BLIP-2 Image Captioning — Streamlit")
    st.markdown(
        "Upload an image and BLIP-2 will generate a short caption. This demo uses a Hugging Face BLIP-2 checkpoint."
    )

    # Sidebar controls
    st.sidebar.header("Settings")
    model_name = st.sidebar.selectbox(
        "Model checkpoint",
        (
            "Salesforce/blip2-flan-t5-large",
            "Salesforce/blip2-flan-t5-xl",
            "Salesforce/blip2-opt-2.7b",
        ),
        index=0,
        help="Choose the BLIP-2 variant. xl/2.7b need more GPU memory; use large for CPU or small GPUs.",
    )
    use_gpu = st.sidebar.checkbox("Use GPU if available", value=True)
    max_tokens = st.sidebar.slider("Max new tokens", min_value=16, max_value=128, value=64, step=8)
    num_beams = st.sidebar.slider("Beams (higher=better, slower)", min_value=1, max_value=6, value=3)

    # Detect device: prefer CUDA only when available AND the user opted in.
    device = "cpu"
    if use_gpu and torch.cuda.is_available():
        device = f"cuda:{torch.cuda.current_device()}"
    st.sidebar.write(f"Running on: **{device}**")

    # Load model (cached across reruns by st.cache_resource)
    with st.spinner("Loading model — first load can take a while..."):
        processor, model = load_model(model_name, device)

    uploaded_file = st.file_uploader(
        "Upload an image", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=False
    )

    if uploaded_file is not None:
        try:
            # Force RGB: BLIP-2 expects 3-channel input (handles RGBA/grayscale uploads).
            image = Image.open(io.BytesIO(uploaded_file.read())).convert("RGB")
        except Exception as e:
            st.error(f"Couldn't open image: {e}")
            return

        st.image(image, caption="Input image", use_column_width=True)

        if st.button("Generate caption"):
            with st.spinner("Generating caption..."):
                try:
                    caption = generate_caption(
                        processor, model, image,
                        device=device, max_tokens=max_tokens, num_beams=num_beams,
                    )
                    st.success("Caption generated")
                    st.markdown(f"### ✨ Caption\n{caption}")
                except Exception as e:
                    st.error(f"Error while generating caption: {e}")
    else:
        st.info("Upload an image to get started. You can also try one of the example images below.")
        col1, col2, col3 = st.columns(3)
        sample_urls = [
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/image_to_text.png",
            "https://images.unsplash.com/photo-1503023345310-bd7c1de61c7d",
            "https://images.unsplash.com/photo-1519681393784-d120267933ba",
        ]
        # enumerate gives the 1-based button label directly instead of
        # re-searching the list with sample_urls.index(url) on every iteration.
        for idx, (col, url) in enumerate(zip((col1, col2, col3), sample_urls), start=1):
            if col.button(f"Use sample {idx}"):
                try:
                    # timeout so a dead/slow URL can't hang the app indefinitely
                    im = Image.open(urlopen(url, timeout=10)).convert("RGB")
                    st.image(im, use_column_width=True)
                    caption = generate_caption(
                        processor, model, im,
                        device=device, max_tokens=max_tokens, num_beams=num_beams,
                    )
                    st.markdown(f"### ✨ Caption\n{caption}")
                except Exception as e:
                    st.error(f"Failed to load sample image: {e}")


if __name__ == "__main__":
    main()