Spaces:
Sleeping
Sleeping
# Standard library imports first, then third-party, per PEP 8 grouping.
import io

import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Must be the first Streamlit command executed in the script.
st.set_page_config(page_title="BLIP-2 Image Captioner", layout="centered")
@st.cache_resource(show_spinner=False)
def load_model(model_name: str, device: str):
    """Load and cache the BLIP-2 processor and model.

    Fix: the original function was called on every Streamlit rerun (every
    widget interaction) and reloaded the multi-GB checkpoint from disk each
    time, even though the caller's comment said "cached". ``st.cache_resource``
    memoizes on ``(model_name, device)`` so the model is loaded once per
    combination and shared across reruns/sessions.

    Args:
        model_name: Hugging Face checkpoint id of a BLIP-2 variant.
        device: Torch device string, e.g. "cpu" or "cuda:0".

    Returns:
        Tuple of ``(processor, model)`` with the model moved to *device*
        and switched to eval mode.
    """
    processor = Blip2Processor.from_pretrained(model_name)
    # float16 halves memory and speeds up inference, but is poorly supported
    # on CPU, so only request it for CUDA devices.
    try:
        dtype = torch.float16 if device.startswith("cuda") else torch.float32
        model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype=dtype)
    except Exception as e:
        # Best-effort fallback: load with the checkpoint's default dtype
        # rather than failing the whole app.
        st.warning(f"Model loaded with fallback dtype due to: {e}")
        model = Blip2ForConditionalGeneration.from_pretrained(model_name)
    model.to(device)
    model.eval()
    return processor, model
def generate_caption(processor, model, image: Image.Image, device: str, max_tokens: int = 64, num_beams: int = 3):
    """Generate a single caption for *image* with a loaded BLIP-2 model.

    Fix: the original moved inputs to *device* but left ``pixel_values`` as
    float32; when the model was loaded in float16 (the CUDA path in
    ``load_model``) generation raised a dtype-mismatch RuntimeError
    ("expected scalar type Half but found Float"). Casting the inputs to
    ``model.dtype`` is safe because ``BatchFeature.to`` only casts
    floating-point tensors, so integer token ids are unaffected.

    Args:
        processor: BLIP-2 processor (image preprocessing + tokenizer).
        model: Loaded ``Blip2ForConditionalGeneration`` instance.
        image: Input image (RGB ``PIL.Image``).
        device: Torch device string the model lives on.
        max_tokens: Maximum number of new tokens to generate.
        num_beams: Beam-search width (1 = greedy).

    Returns:
        The decoded caption string, stripped of surrounding whitespace.
    """
    # Move to the model's device AND dtype in one call.
    inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)
    # Deterministic beam search; no gradients needed for inference.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            num_beams=num_beams,
            early_stopping=True,
            do_sample=False,
        )
    caption = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return caption
def main():
    """Render the Streamlit UI: sidebar settings, image upload, captioning.

    Flow: sidebar picks checkpoint/device/generation params; the model is
    loaded (cached in ``load_model``); an uploaded image — or one of three
    sample images — is captioned via ``generate_caption``.
    """
    st.title("🖼️ BLIP-2 Image Captioning — Streamlit")
    st.markdown(
        "Upload an image and BLIP-2 will generate a short caption. This demo uses a Hugging Face BLIP-2 checkpoint."
    )
    # Sidebar controls
    st.sidebar.header("Settings")
    model_name = st.sidebar.selectbox(
        "Model checkpoint",
        (
            "Salesforce/blip2-flan-t5-large",
            "Salesforce/blip2-flan-t5-xl",
            "Salesforce/blip2-opt-2.7b"
        ),
        index=0,
        help="Choose the BLIP-2 variant. xl/2.7b need more GPU memory; use large for CPU or small GPUs."
    )
    use_gpu = st.sidebar.checkbox("Use GPU if available", value=True)
    max_tokens = st.sidebar.slider("Max new tokens", min_value=16, max_value=128, value=64, step=8)
    num_beams = st.sidebar.slider("Beams (higher=better, slower)", min_value=1, max_value=6, value=3)
    # Detect device: GPU only when requested AND actually present.
    device = "cpu"
    if use_gpu and torch.cuda.is_available():
        device = f"cuda:{torch.cuda.current_device()}"
    st.sidebar.write(f"Running on: **{device}**")
    # Load model (cached via st.cache_resource in load_model)
    with st.spinner("Loading model — first load can take a while..."):
        processor, model = load_model(model_name, device)
    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=False)
    if uploaded_file is not None:
        try:
            # Force RGB: BLIP-2 preprocessing expects 3-channel input.
            image = Image.open(io.BytesIO(uploaded_file.read())).convert("RGB")
        except Exception as e:
            st.error(f"Couldn't open image: {e}")
            return
        st.image(image, caption="Input image", use_column_width=True)
        if st.button("Generate caption"):
            with st.spinner("Generating caption..."):
                try:
                    caption = generate_caption(processor, model, image, device=device, max_tokens=max_tokens, num_beams=num_beams)
                    st.success("Caption generated")
                    st.markdown(f"### ✨ Caption\n{caption}")
                except Exception as e:
                    st.error(f"Error while generating caption: {e}")
    else:
        st.info("Upload an image to get started. You can also try one of the example images below.")
        col1, col2, col3 = st.columns(3)
        sample_urls = [
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/image_to_text.png",
            "https://images.unsplash.com/photo-1503023345310-bd7c1de61c7d",
            "https://images.unsplash.com/photo-1519681393784-d120267933ba",
        ]
        # enumerate instead of the original O(n) sample_urls.index(url) lookup;
        # labels ("Use sample 1".."Use sample 3") are identical.
        for i, (c, url) in enumerate(zip((col1, col2, col3), sample_urls), start=1):
            if c.button(f"Use sample {i}"):
                try:
                    from urllib.request import urlopen
                    im = Image.open(urlopen(url)).convert("RGB")
                    st.image(im, use_column_width=True)
                    caption = generate_caption(processor, model, im, device=device, max_tokens=max_tokens, num_beams=num_beams)
                    st.markdown(f"### ✨ Caption\n{caption}")
                except Exception as e:
                    st.error(f"Failed to load sample image: {e}")
# Script entry point — run the UI only when executed directly.
if __name__ == "__main__":
    main()