File size: 2,544 Bytes
47db847
 
 
 
 
705bdac
 
47db847
 
 
 
 
 
 
ece2cef
 
 
47db847
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ece2cef
47db847
ece2cef
47db847
705bdac
47db847
 
 
 
 
 
 
 
 
705bdac
 
47db847
 
 
 
705bdac
 
 
 
 
ece2cef
705bdac
ece2cef
 
 
 
 
 
 
 
 
 
705bdac
 
ece2cef
705bdac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import streamlit as st
import torch
from PIL import Image
import numpy as np
from transformers import AutoProcessor, AutoModelForCausalLM
from io import BytesIO
import base64

# Initialize Florence model.
# Streamlit re-executes this entire script on every widget interaction, so the
# expensive model/processor load is wrapped in ``st.cache_resource``: it runs
# once per server process and later reruns reuse the same objects.
FLORENCE_MODEL_ID = "microsoft/Florence-2-base"
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_florence(model_id, target_device):
    """Load a Florence-2 model (moved to *target_device*, eval mode) and its processor.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # model repository; consider pinning ``revision=`` to a reviewed commit.
    model = (
        AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
        .to(target_device)
        .eval()
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    return model, processor


florence_model, florence_processor = _load_florence(FLORENCE_MODEL_ID, device)

def generate_caption(
    image,
    *,
    task_prompt="<MORE_DETAILED_CAPTION>",
    max_new_tokens=1024,
    num_beams=3,
):
    """Generate a caption for *image* using the Florence-2 model.

    Args:
        image: A PIL image. It is converted to RGB first to avoid channel
            errors (e.g. RGBA or greyscale uploads).
        task_prompt: Florence-2 task token passed as the text input. Defaults
            to the detailed-caption task this app was built around.
            NOTE(review): non-caption tasks may need
            ``florence_processor.post_process_generation`` rather than plain
            decoding; only caption-style prompts are expected here.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam-search width; decoding is deterministic (no sampling).

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed.
    """
    rgb_image = image.convert("RGB")

    # Tokenize the task prompt and preprocess the image, then move every
    # tensor to the model's device.
    inputs = florence_processor(
        text=task_prompt, images=rgb_image, return_tensors="pt"
    ).to(device)

    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        early_stopping=False,
        do_sample=False,
        num_beams=num_beams,
    )

    # A single image/prompt pair was submitted, so only the first sequence matters.
    return florence_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Streamlit UI
st.title("Florence 2 Caption Generator")
st.write("Upload an image to generate a caption:")

# Image upload input
uploaded_image = st.file_uploader("Choose an Image", type=["jpg", "jpeg", "png"])

# If an image is uploaded
if uploaded_image is not None:
    # Image.open only parses the header; load() forces a full decode so that a
    # corrupt or truncated upload is reported here as a friendly message
    # instead of surfacing as a raw traceback from st.image / generate_caption.
    try:
        image = Image.open(uploaded_image)
        image.load()
    except OSError:  # PIL's UnidentifiedImageError is a subclass of OSError
        image = None
        st.error("The uploaded file could not be read as an image.")

    if image is not None:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Generate caption when button is pressed
        if st.button("Generate Caption"):
            # Model inference can take noticeable time; show progress meanwhile.
            with st.spinner("Generating caption..."):
                caption = generate_caption(image)
            st.subheader("Generated Caption:")
            st.write(caption)

# ✅ API Mode: Handle API Requests
def _decode_base64_payload(payload):
    """Decode a base64 string that arrived through a URL query parameter.

    Form-style query-string decoding turns a literal ``+`` into a space, and
    a non-validating base64 decode would silently drop those spaces,
    corrupting the data. This helper restores them. It also accepts the
    URL-safe alphabet (``-``/``_``) alongside the standard one (``+``/``/``)
    and re-adds ``=`` padding that clients often strip from URLs.

    Raises:
        ValueError (including binascii.Error): if the payload is not valid base64.
    """
    cleaned = "".join(payload.replace(" ", "+").split())
    cleaned += "=" * (-len(cleaned) % 4)
    # altchars maps '-'/'_' onto '+'/'/', while '+'/'/' themselves stay valid.
    return base64.b64decode(cleaned, altchars=b"-_")


def handle_api_request():
    """Answer an API-style request whose ``image`` query parameter holds a base64 image.

    The decoded image is captioned with ``generate_caption`` and rendered with
    ``st.json`` as ``{"caption": ...}``. Any failure (bad base64, unreadable
    image, model error) is rendered as ``{"error": ...}`` instead. Does
    nothing when no ``image`` parameter is present.
    """
    query_params = st.query_params
    if "image" not in query_params:
        return

    try:
        image_bytes = BytesIO(_decode_base64_payload(query_params["image"]))
        image = Image.open(image_bytes).convert("RGB")  # Ensure it's RGB
        caption = generate_caption(image)
    except Exception as e:  # request boundary: report every failure as JSON
        st.json({"error": str(e)})
    else:
        st.json({"caption": caption})  # Return JSON response

# API mode is triggered purely by the presence of an ``image`` query
# parameter; such requests are answered by handle_api_request().
_api_mode_requested = "image" in st.query_params
if _api_mode_requested:
    handle_api_request()