"""Streamlit app: generate image captions with BLIP-2 (Salesforce/blip2-opt-2.7b)."""

import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# st.set_page_config must be the FIRST Streamlit command executed in a script
# run, so it has to come before the cached model load below — otherwise
# Streamlit raises StreamlitAPIException.
st.set_page_config(
    page_title="BLIP-2 Image Captioning",
    page_icon="📸",
    layout="centered",
)


# st.cache_resource loads the model and processor once per server process,
# so re-runs of the script do not re-download or re-initialize them.
@st.cache_resource
def load_blip_model():
    """Load the BLIP-2 processor and model from Hugging Face.

    Returns:
        tuple: (Blip2Processor, Blip2ForConditionalGeneration)
    """
    model_name = "Salesforce/blip2-opt-2.7b"
    processor = Blip2Processor.from_pretrained(model_name)
    if torch.cuda.is_available():
        # float16 roughly halves GPU memory usage for this 2.7B model.
        model = Blip2ForConditionalGeneration.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.float16
        )
    else:
        # CPU path: keep default float32; device_map="auto" places weights.
        model = Blip2ForConditionalGeneration.from_pretrained(
            model_name, device_map="auto"
        )
    return processor, model


# Load (or fetch from cache) the model and processor.
processor, model = load_blip_model()

st.title("📸 BLIP-2 Image Captioning")
st.markdown("### Generate captions for your images using a powerful vision-language model.")
st.markdown("---")

# File uploader widget for the user to upload an image.
uploaded_file = st.file_uploader(
    "Upload an image",
    type=["jpg", "jpeg", "png", "webp"],
    help="Drag and drop or click to upload your image.",
)

if uploaded_file is not None:
    try:
        # Normalize to RGB so grayscale/RGBA/palette images work too.
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption="Uploaded Image", use_column_width=True, channels="RGB")

        if st.button("Generate Caption"):
            with st.spinner("Generating caption..."):
                # Match the model's dtype as well as its device: .to(device)
                # alone leaves pixel_values in float32, which crashes the
                # half-precision model on GPU. BatchFeature.to casts only
                # floating-point tensors, so integer inputs are unaffected.
                inputs = processor(images=image, return_tensors="pt").to(
                    model.device, dtype=model.dtype
                )
                # max_new_tokens bounds only the generated text; max_length
                # would also count the prompt toward the limit.
                outputs = model.generate(**inputs, max_new_tokens=50)
                # BLIP-2 decodes often carry leading/trailing whitespace.
                caption = processor.decode(outputs[0], skip_special_tokens=True).strip()

            st.success("Caption generated!")
            st.markdown("### **Generated Caption:**")
            st.info(caption.capitalize())
    except Exception as e:
        # Broad catch is deliberate at this UI boundary: surface the error to
        # the user instead of crashing the app.
        st.error(f"An error occurred: {e}")
        st.markdown("Please try uploading a different image or check the model availability.")
else:
    st.info("Upload an image to get started!")