import streamlit as st
import torch
from PIL import Image
import numpy as np
from transformers import AutoProcessor, AutoModelForCausalLM
from io import BytesIO
import base64
# Initialize Florence model
FLORENCE_MODEL_ID = "microsoft/Florence-2-base"
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_florence(model_id, target_device):
    """Load the Florence-2 model and processor once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded objects avoids re-reading the weights on each rerun.

    Args:
        model_id: Hugging Face Hub identifier of the checkpoint to load.
        target_device: Torch device string the model is moved to.

    Returns:
        A ``(model, processor)`` tuple; the model is in eval mode.
    """
    # NOTE: trust_remote_code=True allows the custom modeling code shipped in
    # the Hub repository to run locally; only use it with a trusted model id.
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    model = model.to(target_device).eval()
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    return model, processor


florence_model, florence_processor = _load_florence(FLORENCE_MODEL_ID, device)
def generate_caption(image):
    """Produce a detailed Florence 2 caption for a PIL image.

    The image is converted to RGB, passed through the module-level processor
    together with the ``<MORE_DETAILED_CAPTION>`` task prompt, and decoded
    from a 3-beam, non-sampling generation of at most 1024 new tokens.

    Args:
        image: Object exposing ``convert("RGB")`` (e.g. a ``PIL.Image.Image``).

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    # Normalise the channel layout up front to avoid channel errors.
    rgb_image = image.convert("RGB")

    # Build the model inputs and move them to the same device as the model.
    model_inputs = florence_processor(
        text="<MORE_DETAILED_CAPTION>",
        images=rgb_image,
        return_tensors="pt",
    ).to(device)

    # Generation settings are collected first, then handed over in one call.
    generation_kwargs = {
        "input_ids": model_inputs["input_ids"],
        "pixel_values": model_inputs["pixel_values"],
        "max_new_tokens": 1024,
        "early_stopping": False,
        "do_sample": False,
        "num_beams": 3,
    }
    output_ids = florence_model.generate(**generation_kwargs)

    # Only one image was submitted, so the first decoded entry is the caption.
    decoded_texts = florence_processor.batch_decode(
        output_ids, skip_special_tokens=True
    )
    return decoded_texts[0]
# Streamlit UI: page header and instructions
st.title("Florence 2 Caption Generator")
st.write("Upload an image to generate a caption:")

# Image upload input (restricted to common raster formats)
uploaded_image = st.file_uploader(
    "Choose an Image",
    type=["jpg", "jpeg", "png"],
)

# Everything below needs an uploaded file to work with
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    st.image(
        image,
        caption="Uploaded Image",
        use_container_width=True,
    )

    # Only run the model once the user presses the button
    if st.button("Generate Caption"):
        # Compute the caption first so nothing is shown if generation fails
        caption = generate_caption(image)
        st.subheader("Generated Caption:")
        st.write(caption)
# ✅ API Mode: Handle API Requests
def _decode_image_param(encoded):
    """Decode a base64 image payload that travelled in a URL query string.

    Args:
        encoded: Base64 text in the standard or URL-safe alphabet, with or
            without trailing ``=`` padding.

    Returns:
        The decoded raw bytes.

    Raises:
        binascii.Error: If the text cannot be decoded even after
            normalisation.
    """
    # A literal "+" in a query string is commonly decoded to a space, and
    # b64decode silently discards characters outside its alphabet, which
    # would corrupt the payload instead of failing. Restore them first.
    normalized = encoded.replace(" ", "+")
    # Accept the URL-safe alphabet too by mapping it onto the standard one.
    normalized = normalized.replace("-", "+").replace("_", "/")
    # Clients often strip "=" padding from URLs; add back what is missing.
    normalized += "=" * (-len(normalized) % 4)
    return base64.b64decode(normalized)


def handle_api_request():
    """Caption the base64 image passed in the ``image`` query parameter.

    Reads ``st.query_params``; when an ``image`` entry is present it is
    decoded, opened as an RGB image, and captioned with ``generate_caption``.
    The outcome is rendered with ``st.json`` as ``{"caption": ...}`` on
    success or ``{"error": ...}`` on any failure. Does nothing when the
    parameter is absent.
    """
    query_params = st.query_params
    if "image" not in query_params:
        return

    try:
        image_bytes = BytesIO(_decode_image_param(query_params["image"]))
        image = Image.open(image_bytes).convert("RGB")  # Ensure it's RGB
        caption = generate_caption(image)
    except Exception as e:
        # Request boundary: report every failure (bad base64, unreadable
        # image, inference error) to the caller instead of a traceback.
        st.json({"error": str(e)})
    else:
        st.json({"caption": caption})
# API mode: run the query-parameter handler when the page URL carries an
# "image" parameter.
if "image" in st.query_params:
    handle_api_request()