File size: 4,875 Bytes
b45d08e
 
 
 
 
 
 
 
 
 
507db78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b45d08e
 
 
507db78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import streamlit as st
from PIL import Image
import io
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Must run before any other st.* call; sets the browser tab title and layout.
st.set_page_config(page_title="BLIP-2 Image Captioner", layout="centered")

@st.cache_resource
def load_model(model_name: str, device: str):
    """
    Load the BLIP-2 processor and model for *model_name* and move the model
    to *device*.

    Cached with ``st.cache_resource`` so Streamlit reruns with the same
    ``(model_name, device)`` pair reuse the already-loaded objects instead of
    re-downloading / re-initializing the weights.

    Args:
        model_name: Hugging Face checkpoint id (e.g. "Salesforce/blip2-flan-t5-large").
        device: torch device string ("cpu" or "cuda:N").

    Returns:
        Tuple of ``(processor, model)``; the model is in eval mode on *device*.
    """
    processor = Blip2Processor.from_pretrained(model_name)

    # float16 halves memory use and speeds up inference on GPU, but is poorly
    # supported on CPU, so only request it for CUDA devices. Computing the
    # dtype outside the try keeps the except clause scoped to the actual
    # model load, which is the only call here that can fail.
    dtype = torch.float16 if device.startswith("cuda") else torch.float32
    try:
        model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype=dtype)
    except Exception as e:
        # The requested dtype may be unsupported for this checkpoint; retry
        # with the checkpoint's default dtype rather than failing outright.
        st.warning(f"Model loaded with fallback dtype due to: {e}")
        model = Blip2ForConditionalGeneration.from_pretrained(model_name)

    model.to(device)
    model.eval()  # disable dropout etc. — this app only does inference
    return processor, model


def generate_caption(processor, model, image: Image.Image, device: str, max_tokens: int = 64, num_beams: int = 3):
    """
    Generate a caption for *image* with a BLIP-2 model.

    Args:
        processor: Blip2Processor used for both image preprocessing and decoding.
        model: Blip2ForConditionalGeneration already placed on *device*.
        image: PIL image (expected RGB — callers convert before passing).
        device: torch device string the model lives on.
        max_tokens: upper bound on newly generated tokens.
        num_beams: beam-search width (1 = greedy).

    Returns:
        The decoded caption, stripped of surrounding whitespace.
    """
    # Cast inputs to the model's dtype as well as its device: when the model
    # was loaded in float16 (the CUDA path in load_model), the processor's
    # float32 pixel values would otherwise raise a dtype-mismatch error.
    # BatchFeature.to only casts floating-point tensors, so token ids are safe.
    inputs = processor(images=image, return_tensors="pt").to(device=device, dtype=model.dtype)

    # Deterministic beam search; no gradients needed for inference.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            num_beams=num_beams,
            early_stopping=True,
            do_sample=False,
        )

    caption = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return caption


def main():
    """Render the Streamlit UI: settings sidebar, uploader, and caption output."""
    st.title("🖼️ BLIP-2 Image Captioning — Streamlit")
    st.markdown(
        "Upload an image and BLIP-2 will generate a short caption. This demo uses a Hugging Face BLIP-2 checkpoint."
    )

    # Sidebar controls
    st.sidebar.header("Settings")

    model_name = st.sidebar.selectbox(
        "Model checkpoint",
        (
            "Salesforce/blip2-flan-t5-large",
            "Salesforce/blip2-flan-t5-xl",
            "Salesforce/blip2-opt-2.7b"
        ),
        index=0,
        help="Choose the BLIP-2 variant. xl/2.7b need more GPU memory; use large for CPU or small GPUs."
    )

    use_gpu = st.sidebar.checkbox("Use GPU if available", value=True)
    max_tokens = st.sidebar.slider("Max new tokens", min_value=16, max_value=128, value=64, step=8)
    num_beams = st.sidebar.slider("Beams (higher=better, slower)", min_value=1, max_value=6, value=3)

    # Detect device: fall back to CPU when the user opts out or CUDA is absent.
    device = "cpu"
    if use_gpu and torch.cuda.is_available():
        device = f"cuda:{torch.cuda.current_device()}"
    st.sidebar.write(f"Running on: **{device}**")

    # Load model (cached by load_model via st.cache_resource).
    with st.spinner("Loading model — first load can take a while..."):
        processor, model = load_model(model_name, device)

    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "webp"], accept_multiple_files=False)

    if uploaded_file is not None:
        try:
            # Convert to RGB so grayscale/RGBA uploads match the model's input.
            image = Image.open(io.BytesIO(uploaded_file.read())).convert("RGB")
        except Exception as e:
            st.error(f"Couldn't open image: {e}")
            return

        st.image(image, caption="Input image", use_column_width=True)

        if st.button("Generate caption"):
            with st.spinner("Generating caption..."):
                try:
                    caption = generate_caption(processor, model, image, device=device, max_tokens=max_tokens, num_beams=num_beams)
                    st.success("Caption generated")
                    st.markdown(f"### ✨ Caption\n{caption}")
                except Exception as e:
                    st.error(f"Error while generating caption: {e}")

    else:
        st.info("Upload an image to get started. You can also try one of the example images below.")
        col1, col2, col3 = st.columns(3)
        sample_urls = [
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/image_to_text.png",
            "https://images.unsplash.com/photo-1503023345310-bd7c1de61c7d",
            "https://images.unsplash.com/photo-1519681393784-d120267933ba",
        ]
        # Local import: urllib is only needed on this demo path.
        from urllib.request import urlopen
        # enumerate instead of list.index() — avoids an O(n) lookup per button
        # and gives a stable 1-based label.
        for idx, (c, url) in enumerate(zip((col1, col2, col3), sample_urls), start=1):
            if c.button(f"Use sample {idx}"):
                try:
                    # Timeout so a slow/unreachable host can't hang the app.
                    im = Image.open(urlopen(url, timeout=15)).convert("RGB")
                    st.image(im, use_column_width=True)
                    caption = generate_caption(processor, model, im, device=device, max_tokens=max_tokens, num_beams=num_beams)
                    st.markdown(f"### ✨ Caption\n{caption}")
                except Exception as e:
                    st.error(f"Failed to load sample image: {e}")


# Script entry point: only run the UI when executed directly, not on import.
if __name__ == "__main__":
    main()