|
|
from PIL import Image
import streamlit as st
import torch
from transformers import BitsAndBytesConfig, Blip2ForConditionalGeneration, Blip2Processor
|
|
|
|
|
|
|
|
# Page chrome: browser-tab title/icon plus a wide layout so the image and
# caption have room side by side.
st.set_page_config(page_title="BLIP-2 Image Captioning", page_icon="📸", layout="wide")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """
    Load the BLIP-2 processor and model from the Hugging Face Hub.

    Uses the smaller `Salesforce/blip2-opt-2.7b` checkpoint, which is more
    suitable for Hugging Face's free tier though it may still need
    significant resources. On a CUDA machine the weights are loaded in 8-bit
    (via bitsandbytes) with float16 compute to cut memory usage; on CPU we
    fall back to a plain float32 load, because bitsandbytes 8-bit
    quantization requires a GPU and would otherwise fail outright.

    Returns:
        tuple: (processor, model, device) on success, or (None, None, None)
        on failure — the error is surfaced in the Streamlit UI rather than
        raised, so the page can render a friendly message.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")

        if device == "cuda":
            # The quantization_config API replaces the deprecated
            # `load_in_8bit=True` kwarg on from_pretrained.
            model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-opt-2.7b",
                device_map="auto",
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                torch_dtype=torch.float16,
            )
        else:
            # CPU path: no 8-bit support without a GPU, so load full
            # precision weights instead of crashing during startup.
            model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-opt-2.7b",
                torch_dtype=torch.float32,
            )
        return processor, model, device
    except Exception as e:
        st.error(f"Error loading the model: {e}")
        st.info("The model is very large and may require a GPU with at least 15GB of VRAM. "
                "If you're seeing this error, the free tier of Hugging Face Spaces might not be enough.")
        return None, None, None
|
|
|
|
|
|
|
|
# Intro blurb shown under the page title (markdown link to the model card).
_INTRO_TEXT = (
    "Upload an image, and this application will generate a descriptive caption using the powerful "
    "[BLIP-2 model](https://huggingface.co/Salesforce/blip2-opt-2.7b) from Hugging Face."
)

st.title("📸 BLIP-2 Image Captioning AI")
st.write(_INTRO_TEXT)
|
|
|
|
|
|
|
|
# Load (cached) model once per process; all three are None if loading failed.
processor, model, device = load_model()

if model and processor:

    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png", "bmp"],
        help="Upload an image file to get a caption."
    )

    if uploaded_file is not None:

        # `use_container_width` replaces the deprecated `use_column_width`
        # kwarg, which is removed in current Streamlit releases.
        st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
        st.write("")
        st.info("Generating caption...")

        try:
            # BLIP-2 expects 3-channel RGB; convert() normalizes grayscale,
            # RGBA, and palette uploads.
            raw_image = Image.open(uploaded_file).convert("RGB")

            # Match the model's compute dtype: float16 only on CUDA —
            # half-precision inference is unsupported/broken on CPU.
            dtype = torch.float16 if device == "cuda" else torch.float32
            inputs = processor(images=raw_image, return_tensors="pt").to(device, dtype)
            out = model.generate(**inputs, max_new_tokens=50)

            caption = processor.decode(out[0], skip_special_tokens=True).strip()

            st.success(f"**Caption:** {caption}")

        except Exception as e:
            st.error(f"An error occurred during caption generation: {e}")

else:
    st.warning("The application could not be initialized. Please check the logs for details.")
|
|
|