# BLIP-2 Image Captioning — Streamlit app (Hugging Face Space)
import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration


# Cache so the processor and 2.7B-parameter model are loaded once per server
# process instead of on every Streamlit re-run (the original comment promised
# this, but the decorator was missing). show_spinner=False keeps the cache
# from emitting a Streamlit element before st.set_page_config() has run —
# set_page_config must be the first Streamlit command on a page.
@st.cache_resource(show_spinner=False)
def load_blip_model():
    """
    Load the BLIP-2 processor and model from Hugging Face.

    Returns:
        tuple: (Blip2Processor, Blip2ForConditionalGeneration). On a CUDA
        machine the weights are loaded in float16 to roughly halve GPU
        memory usage; on CPU the default float32 dtype is used.
    """
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    if torch.cuda.is_available():
        # float16 halves GPU memory with negligible quality impact here.
        model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            device_map="auto",
            torch_dtype=torch.float16,
        )
    else:
        # Keep device_map="auto" so accelerate places weights appropriately.
        model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            device_map="auto",
        )
    return processor, model
# st.set_page_config must be the FIRST Streamlit command executed on the
# page, so configure the page before the (slow) model load runs.
st.set_page_config(
    page_title="BLIP-2 Image Captioning",
    page_icon="📸",
    layout="centered"
)

st.title("📸 BLIP-2 Image Captioning")
st.markdown("### Generate captions for your images using a powerful vision-language model.")
st.markdown("---")

# Load the model and processor after the page is configured.
processor, model = load_blip_model()

# File uploader widget for the user to upload an image.
uploaded_file = st.file_uploader(
    "Upload an image",
    type=["jpg", "jpeg", "png", "webp"],
    help="Drag and drop or click to upload your image."
)
if uploaded_file is not None:
    try:
        # Decode the upload; convert to RGB so RGBA/palette images also work.
        image = Image.open(uploaded_file).convert("RGB")
        # use_container_width replaces the deprecated use_column_width.
        st.image(image, caption="Uploaded Image", use_container_width=True, channels="RGB")

        if st.button("Generate Caption"):
            with st.spinner("Generating caption..."):
                # Move the pixel tensors to the model's device AND dtype:
                # the processor returns float32, which would not match a
                # float16 model on GPU.
                inputs = processor(images=image, return_tensors="pt").to(
                    model.device, model.dtype
                )
                # max_new_tokens bounds only the generated text, unlike the
                # deprecated max_length (which also counts prompt tokens).
                outputs = model.generate(**inputs, max_new_tokens=50)
                caption = processor.decode(outputs[0], skip_special_tokens=True)

                st.success("Caption generated!")
                st.markdown("### **Generated Caption:**")
                st.info(caption.capitalize())
    except Exception as e:
        # Broad catch is deliberate: surface any decode/model failure in the
        # UI instead of crashing the app.
        st.error(f"An error occurred: {e}")
        st.markdown("Please try uploading a different image or check the model availability.")
else:
    st.info("Upload an image to get started!")