"""Streamlit app that captions images with Microsoft's Florence-2 model.

Two entry paths are supported:

* Interactive: upload an image and press "Generate Caption".
* Query-parameter mode: open the page with ``?image=<base64 image bytes>``
  and the caption (or an error) is rendered on the page as a JSON widget.
"""

import base64
from io import BytesIO

import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

FLORENCE_MODEL_ID = "microsoft/Florence-2-base"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_florence(model_id: str, target_device: str):
    """Load the Florence model and its processor once per server process.

    Streamlit re-executes this script on every user interaction; without the
    ``st.cache_resource`` wrapper the weights would be re-loaded on each
    rerun.

    Args:
        model_id: Hugging Face model identifier to load.
        target_device: Torch device string the model is moved to.

    Returns:
        A ``(model, processor)`` tuple; the model is in eval mode.
    """
    model = (
        AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
        .to(target_device)
        .eval()
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    return model, processor


# Module-level handles used by the rest of the script.
florence_model, florence_processor = _load_florence(FLORENCE_MODEL_ID, device)


def generate_caption(image: Image.Image, task_prompt: str = "") -> str:
    """Generate a caption for the given image using Florence 2.

    Args:
        image: PIL image to caption; it is converted to RGB before use.
        task_prompt: Text passed to the processor alongside the image.
            Defaults to the empty string, which is what this script has
            always sent.
            NOTE(review): the Florence-2 model card drives the model with
            task tokens such as "<CAPTION>" -- confirm whether one is
            intended here.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    # Convert image to RGB format to avoid channel errors
    # (e.g. RGBA or palette PNGs).
    rgb_image = image.convert("RGB")

    # Prepare the input for the Florence model.
    inputs = florence_processor(
        text=task_prompt, images=rgb_image, return_tensors="pt"
    ).to(device)

    # Deterministic beam search (no sampling).
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Decode the generated token ids back to text.
    return florence_processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]


def _decode_base64_image(encoded: str) -> Image.Image:
    """Decode a base64 string taken from a URL query parameter into an image.

    Query-string decoding turns ``+`` into a space, and callers may send the
    URL-safe alphabet (``-``/``_``) or drop the trailing ``=`` padding, so the
    text is normalised before decoding.

    Args:
        encoded: Base64 text of the encoded image file.

    Returns:
        The opened PIL image (not yet converted to RGB).

    Raises:
        binascii.Error: If the text is not decodable base64.
        OSError: If the decoded bytes are not an image PIL can open.
    """
    normalized = encoded.replace(" ", "+").replace("-", "+").replace("_", "/")
    # Restore padding that is commonly stripped from URL payloads; surplus
    # padding is ignored by the decoder.
    normalized += "=" * (-len(normalized) % 4)
    return Image.open(BytesIO(base64.b64decode(normalized)))


# Streamlit UI
st.title("Florence 2 Caption Generator")
st.write("Upload an image to generate a caption:")

# Image upload input
uploaded_image = st.file_uploader("Choose an Image", type=["jpg", "jpeg", "png"])

# If an image is uploaded, preview it and offer captioning.
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Generate caption when button is pressed
    if st.button("Generate Caption"):
        caption = generate_caption(image)
        st.subheader("Generated Caption:")
        st.write(caption)


# API Mode: Handle API Requests
def handle_api_request() -> None:
    """Caption the image supplied via the ``image`` URL query parameter.

    Does nothing when the parameter is absent. Otherwise the parameter is
    decoded as base64 image bytes, captioned, and the result is rendered with
    ``st.json`` as ``{"caption": ...}``; any failure is rendered as
    ``{"error": ...}``. The output is a JSON widget on the Streamlit page,
    not a raw HTTP JSON response body.
    """
    query_params = st.query_params
    if "image" not in query_params:
        return

    try:
        image = _decode_base64_image(query_params["image"])
        # generate_caption performs the RGB conversion itself.
        caption = generate_caption(image)
    except Exception as e:
        # Boundary handler: report any decoding or model failure on the page
        # instead of letting the script crash.
        st.json({"error": str(e)})
    else:
        st.json({"caption": caption})


# Check if API mode is enabled
if "image" in st.query_params:
    handle_api_request()