File size: 3,112 Bytes
c56ecdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8f6f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c56ecdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8f6f32
 
c56ecdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Use st.cache_resource to load the model and processor once.
# This saves time and memory when the app re-runs.
@st.cache_resource
def load_blip_model():
    """
    Load the BLIP-2 processor and model from Hugging Face.

    Cached with st.cache_resource so the (multi-GB) weights are loaded
    only once per process, not on every Streamlit rerun.

    The model is loaded in float16 when a CUDA device is available
    (roughly halving GPU memory use); CPUs generally lack fast fp16
    kernels, so float32 is used otherwise.

    Returns:
        tuple: (processor, model) — the Blip2Processor and the
        Blip2ForConditionalGeneration instance.
    """
    model_id = "Salesforce/blip2-opt-2.7b"
    processor = Blip2Processor.from_pretrained(model_id)

    # Single load path: only the dtype differs between GPU and CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # NOTE: device_map="auto" requires the `accelerate` package; it places
    # the weights on the available device(s) automatically.
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=dtype,
    )
    return processor, model

# Streamlit requires set_page_config() to be the FIRST Streamlit command in
# the script, so configure the page before the cached model load below
# (st.cache_resource can emit UI elements while loading).
st.set_page_config(
    page_title="BLIP-2 Image Captioning",
    page_icon="📸",
    layout="centered"
)

# Load the model and processor (cached across reruns by st.cache_resource).
processor, model = load_blip_model()

st.title("📸 BLIP-2 Image Captioning")
st.markdown("### Generate captions for your images using a powerful vision-language model.")
st.markdown("---")

# File uploader widget for the user to upload an image
uploaded_file = st.file_uploader(
    "Upload an image", 
    type=["jpg", "jpeg", "png", "webp"],
    help="Drag and drop or click to upload your image."
)

if uploaded_file is not None:
    try:
        # Normalise to RGB: PNG/WebP uploads may carry an alpha channel or
        # be palette-based, which the BLIP-2 processor does not expect.
        image = Image.open(uploaded_file).convert('RGB')

        # Display the uploaded image
        st.image(image, caption="Uploaded Image", use_column_width=True, channels="RGB")

        # Create a button to generate the caption
        if st.button("Generate Caption"):
            with st.spinner("Generating caption..."):
                # Match the model's device AND dtype: when the model was
                # loaded in float16 on GPU, passing float32 pixel values
                # would raise a dtype-mismatch error in the vision tower.
                inputs = processor(images=image, return_tensors="pt").to(
                    model.device, model.dtype
                )

                # max_new_tokens bounds only the generated caption length;
                # max_length is deprecated and also counts prompt tokens.
                outputs = model.generate(**inputs, max_new_tokens=50)

                # Decode the generated caption tokens to a string
                caption = processor.decode(outputs[0], skip_special_tokens=True)

            # Display the generated caption
            st.success("Caption generated!")
            st.markdown("### **Generated Caption:**")
            st.info(caption.capitalize())

    except Exception as e:
        # Broad catch is acceptable at this top-level UI boundary: surface
        # the error to the user instead of crashing the whole app.
        st.error(f"An error occurred: {e}")
        st.markdown("Please try uploading a different image or check the model availability.")

else:
    st.info("Upload an image to get started!")