File size: 6,900 Bytes
b5afce3
 
 
b3d0543
b5afce3
 
ad912ef
b5afce3
 
b3d0543
b5afce3
 
 
 
8edab76
b5afce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3d0543
 
b5afce3
 
 
b3d0543
 
 
 
 
 
b5afce3
 
 
 
 
b3d0543
b5afce3
b3d0543
b5afce3
b3d0543
 
b5afce3
b3d0543
 
b5afce3
b3d0543
b5afce3
 
 
 
 
b3d0543
b5afce3
 
 
b3d0543
b5afce3
 
 
 
 
 
b3d0543
b5afce3
 
b3d0543
b5afce3
 
 
 
 
 
 
 
b3d0543
 
b5afce3
 
bd9009b
 
 
 
 
b3d0543
b889520
b3d0543
 
 
 
b889520
bd9009b
 
 
b3d0543
bd9009b
 
b3d0543
 
 
 
bd9009b
 
b5afce3
 
 
 
 
 
 
 
 
 
b3d0543
b5afce3
 
 
 
 
dada616
b5afce3
 
 
 
 
 
 
 
 
 
b3d0543
b5afce3
 
b3d0543
 
 
b5afce3
b3d0543
 
 
 
b5afce3
 
b3d0543
 
 
b5afce3
 
 
 
b3d0543
 
b5afce3
 
b3d0543
 
b5afce3
 
 
 
b3d0543
b5afce3
b3d0543
 
b5afce3
b3d0543
b5afce3
 
 
 
 
 
 
 
 
 
 
 
b3d0543
 
b5afce3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import html
import io
import time

import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Configure the browser tab (title/icon), wide layout and an initially
# expanded sidebar. The emoji literals are written as real UTF-8 characters;
# a mis-decoded byte sequence here would show up as garbage in the tab title.
st.set_page_config(
    page_title="🚀 BLIP-2 Caption Generator",
    page_icon="🚀",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS used by the header banner (.main-header), the "upload an image"
# placeholder (.upload-section) and the rendered caption (.caption-box).
st.markdown("""
<style>
    .main-header {
        text-align: center;
        padding: 2rem 0;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 2rem;
    }
    .upload-section {
        border: 2px dashed #ccc;
        border-radius: 10px;
        padding: 2rem;
        text-align: center;
        margin: 1rem 0;
    }
    .caption-box {
        background-color: #f0f2f6;
        border-left: 4px solid #667eea;
        padding: 1rem;
        border-radius: 5px;
        margin: 1rem 0;
    }
</style>
""", unsafe_allow_html=True)

@st.cache_resource
def _load_model_cached():
    """Load the BLIP-2 processor and model once per server process.

    Failures propagate as exceptions instead of being turned into a sentinel
    value: ``st.cache_resource`` does not cache a call that raises, so a
    transient failure (e.g. a network error while downloading the weights)
    is retried on the next rerun instead of being remembered forever.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is ``"cuda"``
        when a GPU is available and ``"cpu"`` otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use the smaller BLIP-2 model for better performance on Hugging Face Spaces
    model_name = "Salesforce/blip2-opt-2.7b"

    processor = Blip2Processor.from_pretrained(model_name)
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_name,
        # Half precision on GPU, full precision on CPU.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        # On GPU the weights are placed via device_map; on CPU they are
        # moved explicitly below.
        device_map="auto" if device == "cuda" else None,
    )

    if device == "cpu":
        model = model.to(device)

    return processor, model, device


def load_model():
    """Return the cached BLIP-2 processor, model and device.

    Returns:
        tuple: ``(processor, model, device)`` on success, or
        ``(None, None, None)`` after reporting the error with ``st.error``.
        Because the failure is not cached, the next rerun (e.g. a page
        refresh) tries to load the model again.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None


# Keep ``load_model.clear()`` working as it did when load_model itself was the
# st.cache_resource-decorated function.
load_model.clear = _load_model_cached.clear

def generate_caption(image, processor, model, device, prompt=""):
    """Generate a caption for an image with the BLIP-2 model.

    Args:
        image: Image to caption (the PIL image opened from the upload).
            Images in any mode other than RGB (e.g. RGBA or palette PNGs,
            grayscale, CMYK TIFFs) are converted to RGB first.
        processor: Processor returned by ``load_model()``.
        model: Model returned by ``load_model()``.
        device: Device the model runs on (``"cuda"`` or ``"cpu"``).
        prompt: Optional text prompt, e.g.
            ``"Question: What is in this image? Answer:"``. Surrounding
            whitespace is ignored; a blank or ``None`` prompt means plain,
            unconditional captioning.

    Returns:
        The generated caption with surrounding whitespace and any echoed
        prompt removed, or ``None`` if generation failed (the error is shown
        with ``st.error``).
    """
    try:
        prompt = (prompt or "").strip()

        if image.mode != "RGB":
            image = image.convert("RGB")

        # Prepare inputs
        if prompt:
            inputs = processor(image, text=prompt, return_tensors="pt")
        else:
            inputs = processor(image, return_tensors="pt")
        # Cast floating-point inputs (pixel_values) to the model's dtype as
        # well as moving them: load_model() puts the model in float16 on GPU,
        # where float32 pixels would fail with a dtype mismatch.
        # BatchFeature.to() leaves integer tensors such as input_ids uncast.
        inputs = inputs.to(device, model.dtype)

        # Generate caption (beam search combined with sampling)
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_length=50,
                num_beams=5,
                temperature=0.7,
                do_sample=True,
                early_stopping=True,
            )

        caption = processor.decode(generated_ids[0], skip_special_tokens=True).strip()
        # Depending on the transformers version, the decoded sequence may start
        # with the prompt tokens; keep only the newly generated answer.
        if prompt and caption.startswith(prompt):
            caption = caption[len(prompt):].strip()
        return caption

    except Exception as e:
        st.error(f"Error generating caption: {str(e)}")
        return None

def _render_header():
    """Render the gradient title banner at the top of the page."""
    st.markdown("""
    <div class="main-header">
        <h1>🚀 BLIP-2 Caption Generator</h1>
        <p>Upload an image and get AI-generated captions instantly!</p>
    </div>
    """, unsafe_allow_html=True)


def _render_sidebar():
    """Render the settings sidebar and return the custom prompt (may be empty)."""
    with st.sidebar:
        st.header("🔧 Settings")
        st.markdown("### Model Information")
        st.info("Using **BLIP-2** (Salesforce/blip2-opt-2.7b)")

        # Custom prompt option
        custom_prompt = st.text_input(
            "Custom Prompt (Optional):",
            placeholder="e.g., 'Question: What is in this image? Answer:'"
        )

        st.markdown("### About")
        st.markdown("""
        This app uses the **BLIP-2** model to generate natural language descriptions of images.
        
        **Features:**
        - 🖼️ Upload any image format
        - 🤖 AI-powered captioning  
        - ⚡ Fast inference
        - 🎯 Optional custom prompts
        """)
    return custom_prompt


def _render_upload_column():
    """Render the uploader and image preview; return ``(uploaded_file, image)``.

    ``image`` is None when nothing was uploaded or the file could not be
    decoded (in which case an error has already been shown).
    """
    st.markdown("### 📤 Upload Image")

    uploaded_file = st.file_uploader(
        "Choose an image file",
        type=["jpg", "jpeg", "png", "bmp", "tiff"],
        help="Upload an image to generate a caption"
    )
    if uploaded_file is None:
        return None, None

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; decode now so corrupt files fail here rather
        # than later inside st.image or the model.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Could not read the uploaded image: {e}")
        return uploaded_file, None

    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Image info
    st.markdown(f"""
    **Image Info:**
    - Size: {image.size[0]} x {image.size[1]} pixels
    - Format: {image.format}
    - Mode: {image.mode}
    """)
    return uploaded_file, image


def _render_caption_column(uploaded_file, image, custom_prompt):
    """Render the caption panel: upload hint, model loading and captioning."""
    st.markdown("### 🔮 Generated Caption")

    if uploaded_file is None:
        st.markdown("""
        <div class="upload-section">
            <h3>👆 Upload an image to get started</h3>
            <p>Supported formats: JPG, PNG, BMP, TIFF</p>
        </div>
        """, unsafe_allow_html=True)
        return
    if image is None:
        # The upload column already reported why the file could not be read.
        return

    with st.spinner("Loading BLIP-2 model..."):
        processor, model, device = load_model()

    if processor is None or model is None:
        st.error("Failed to load the model. Please try refreshing the page.")
        return

    if not st.button("🎯 Generate Caption", type="primary"):
        return

    with st.spinner("Generating caption..."):
        start_time = time.perf_counter()
        caption = generate_caption(image, processor, model, device, custom_prompt)
        elapsed = time.perf_counter() - start_time

    if caption is None:
        # generate_caption has already shown the error.
        return
    if not caption:
        st.warning("The model returned an empty caption.")
        return

    # The caption is model output (and may echo the user's prompt), so escape
    # it before embedding it in raw HTML; keep the HTML on one line so the
    # markdown renderer treats it as a single HTML block.
    safe_caption = html.escape(caption).replace("\n", "<br>")
    st.markdown(
        '<div class="caption-box">'
        "<h4>📝 Caption:</h4>"
        f'<p style="font-size: 16px; font-weight: 500;">{safe_caption}</p>'
        "</div>",
        unsafe_allow_html=True,
    )

    # Performance info
    st.success(f"Caption generated in {elapsed:.2f} seconds")

    # Plain-text copy of the caption (st.code renders a copy button)
    st.code(caption, language=None)


def _render_footer():
    """Render the credits footer."""
    st.markdown("---")
    st.markdown("""
    <div style="text-align: center; color: #666;">
        <p>Built with <strong>Streamlit</strong> and <strong>Hugging Face Transformers</strong></p>
        <p>Powered by <strong>BLIP-2</strong> - Bootstrapping Language-Image Pre-training</p>
    </div>
    """, unsafe_allow_html=True)


def main():
    """Lay out the app: header, settings sidebar, upload/caption columns, footer."""
    _render_header()
    custom_prompt = _render_sidebar()

    col1, col2 = st.columns([1, 1])
    with col1:
        uploaded_file, image = _render_upload_column()
    with col2:
        _render_caption_column(uploaded_file, image, custom_prompt)

    _render_footer()


if __name__ == "__main__":
    main()