# Captionimage / app.py — BLIP-2 image-captioning Streamlit app
# (Hugging Face Space by haris018, commit b8f6f32 "Update app.py", verified)
import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
# Use st.cache_resource to load the model and processor once.
# This saves time and memory when the app re-runs.
@st.cache_resource
def load_blip_model():
    """
    Load the BLIP-2 processor and model from Hugging Face.

    Streamlit caches the result, so the (multi-GB) download and weight
    loading happen only on the first run. On CUDA machines the model is
    loaded in float16 to roughly halve GPU memory; on CPU the default
    dtype is kept. device_map="auto" lets accelerate place the weights.

    Returns:
        tuple: (processor, model) ready for inference.
    """
    # Single source of truth for the checkpoint id (was repeated 3x).
    model_name = "Salesforce/blip2-opt-2.7b"
    processor = Blip2Processor.from_pretrained(model_name)

    # Both branches shared device_map="auto"; only the dtype differs.
    load_kwargs = {"device_map": "auto"}
    if torch.cuda.is_available():
        # float16 halves GPU memory usage with negligible quality loss.
        load_kwargs["torch_dtype"] = torch.float16

    model = Blip2ForConditionalGeneration.from_pretrained(model_name, **load_kwargs)
    return processor, model
# st.set_page_config must be the FIRST Streamlit command the script runs.
# The cached model load below renders a spinner/cache message, so calling
# set_page_config after it raises StreamlitAPIException — configure the
# page first, then load the model.
st.set_page_config(
    page_title="BLIP-2 Image Captioning",
    page_icon="📸",
    layout="centered"
)

# Load the (cached) model and processor.
processor, model = load_blip_model()

# Page header.
st.title("📸 BLIP-2 Image Captioning")
st.markdown("### Generate captions for your images using a powerful vision-language model.")
st.markdown("---")
# Ask the user for the image to caption; the accepted extensions are the
# common web formats that PIL can decode from an in-memory upload.
uploaded_file = st.file_uploader(
    label="Upload an image",
    type=["jpg", "jpeg", "png", "webp"],
    help="Drag and drop or click to upload your image.",
)
if uploaded_file is not None:
    try:
        # Decode the upload and normalize to RGB: BLIP-2 expects 3 channels,
        # while PNG/WebP uploads may arrive as RGBA or palette images.
        image = Image.open(uploaded_file).convert('RGB')

        # Show the user what was uploaded. use_container_width replaces the
        # deprecated use_column_width parameter.
        st.image(image, caption="Uploaded Image", use_container_width=True, channels="RGB")

        if st.button("Generate Caption"):
            with st.spinner("Generating caption..."):
                # Preprocess into model-ready tensors, matching BOTH the
                # model's device and dtype — the processor emits float32
                # pixel values, which would crash against fp16 GPU weights.
                inputs = processor(images=image, return_tensors="pt").to(model.device, model.dtype)

                # max_new_tokens bounds only the generated text, unlike the
                # deprecated max_length, which also counted prompt tokens.
                outputs = model.generate(**inputs, max_new_tokens=50)

                # Decode the generated token ids to a plain string.
                caption = processor.decode(outputs[0], skip_special_tokens=True)

            # Display the generated caption.
            st.success("Caption generated!")
            st.markdown("### **Generated Caption:**")  # was an f-string with no placeholders
            st.info(caption.capitalize())
    except Exception as e:
        # Top-level UI boundary: surface the failure in the page rather than
        # letting the whole app crash on a bad image or OOM.
        st.error(f"An error occurred: {e}")
        st.markdown("Please try uploading a different image or check the model availability.")
else:
    st.info("Upload an image to get started!")