|
|
import os |
|
|
import subprocess |
|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
import gradio as gr |
|
|
from gradio.themes import Soft |
|
|
from gradio.themes.utils import colors, fonts, sizes |
|
|
from typing import Iterable |
|
|
|
|
|
|
|
|
|
|
|
# Custom "steel blue" palette (lightest c50 -> darkest c950), registered on
# gradio's color registry so the theme below can reference it like a built-in hue.
_STEEL_BLUE_SHADES = {
    "c50": "#EBF3F8",
    "c100": "#D3E5F0",
    "c200": "#A8CCE1",
    "c300": "#7DB3D2",
    "c400": "#529AC3",
    "c500": "#4682B4",
    "c600": "#3E72A0",
    "c700": "#36638C",
    "c800": "#2E5378",
    "c900": "#264364",
    "c950": "#1E3450",
}
colors.steel_blue = colors.Color(name="steel_blue", **_STEEL_BLUE_SHADES)
|
|
|
|
|
class SteelBlueTheme(Soft):
    """Soft-based Gradio theme built around the custom steel-blue palette.

    Gradient backgrounds and steel-blue primary buttons are layered on top of
    whatever hues/typography the caller picks (all arguments are keyword-only
    and default to the steel-blue look).
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        # Let Soft wire up the base palette/typography first.
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Collected as a dict purely for readability; "*token" strings are
        # gradio theme variables resolved at render time.
        overrides = {
            "body_background_fill": "linear-gradient(135deg, *primary_100, *primary_200)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_800, *primary_900)",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_text_color": "white",
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            "block_title_text_weight": "600",
            "block_border_width": "2px",
            "block_shadow": "*shadow_drop_lg",
            "button_shadow": "*shadow_drop_lg",
            "button_large_padding": "12px",
        }
        super().set(**overrides)
|
|
|
|
|
|
|
|
# Single theme instance shared by the gr.Blocks UI at the bottom of this file.
steel_blue_theme = SteelBlueTheme()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Best-effort install of flash-attn for faster attention; the app continues
# without it on failure. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE asks the package
# to skip compiling CUDA kernels at install time.
#
# Fixes vs. the original:
#  - env must MERGE with os.environ: passing only the one variable replaces the
#    whole environment, stripping PATH (and CUDA vars), which can make the pip
#    invocation itself fail.
#  - argument list with shell=False instead of a shell string (no shell parsing,
#    no injection surface).
try:
    subprocess.run(
        ['pip', 'install', 'flash-attn', '--no-build-isolation'],
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
        check=True,
    )
except subprocess.CalledProcessError as e:
    print(f"Error installing flash-attn: {e}")
    print("Continuing without flash-attn.")
|
|
|
|
|
|
|
|
# Prefer the GPU when available; models and input tensors below are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
def _load_florence(model_id, label):
    """Load a Florence-2 checkpoint and its processor.

    Args:
        model_id: Hugging Face model id (e.g. 'microsoft/Florence-2-base').
        label: short name used in the failure message ('base' / 'large').

    Returns:
        (model, processor) on success, or (None, None) if loading fails —
        describe_image() checks for None and reports the failure to the user.
    """
    try:
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        return model, processor
    except Exception as e:
        # Broad catch is deliberate: a failed download/load should degrade the
        # UI, not crash the whole app at import time.
        print(f"Error loading {label} model: {e}")
        return None, None


vision_language_model_base, vision_language_processor_base = _load_florence('microsoft/Florence-2-base', 'base')
vision_language_model_large, vision_language_processor_large = _load_florence('microsoft/Florence-2-large', 'large')
|
|
|
|
|
def describe_image(uploaded_image, model_choice):
    """Generate a detailed caption for an image with the selected Florence-2 model.

    Args:
        uploaded_image: PIL.Image.Image (gr.Image with type="pil") or a numpy
            array, which is converted to PIL. None returns a prompt message.
        model_choice: "Florence-2-base" or "Florence-2-large".

    Returns:
        str: the generated caption, or a human-readable error message.
    """
    if uploaded_image is None:
        return "Please upload an image."

    # Registry keeps the model dispatch (and its failure message) in one place
    # instead of a duplicated if/elif chain.
    registry = {
        "Florence-2-base": (vision_language_model_base, vision_language_processor_base,
                            "Base model failed to load."),
        "Florence-2-large": (vision_language_model_large, vision_language_processor_large,
                             "Large model failed to load."),
    }
    if model_choice not in registry:
        return "Invalid model choice."
    model, processor, load_error = registry[model_choice]
    if model is None or processor is None:
        return load_error

    # gr.Image normally yields PIL, but tolerate raw arrays too.
    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.fromarray(uploaded_image)

    # Florence-2 task token; also required by post_process_generation below.
    task = "<MORE_DETAILED_CAPTION>"
    inputs = processor(text=task, images=uploaded_image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,  # deterministic beam search
            num_beams=3,
        )
    # skip_special_tokens=False: Florence-2's post-processor needs the task
    # tokens present in the raw text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    processed_description = processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(uploaded_image.width, uploaded_image.height)
    )
    image_description = processed_description[task]
    print("\nImage description generated!:", image_description)
    return image_description
|
|
|
|
|
|
|
|
# Helper text rendered under the page title; extended with a CPU warning when
# no GPU was detected above.
description = "> Select the model to use for generating the image description. 'Base' is smaller and faster, while 'Large' is more accurate but slower."
if device == "cpu":
    description += " Note: Running on CPU, which may be slow for large models."
|
|
|
|
|
|
|
|
# Sample (image path, model choice) rows surfaced in the Examples widget.
examples = [
    [image_path, model_name]
    for image_path, model_name in (
        ("images/2.jpeg", "Florence-2-large"),
        ("images/1.jpeg", "Florence-2-base"),
        ("images/3.jpeg", "Florence-2-large"),
        ("images/4.jpeg", "Florence-2-large"),
    )
]
|
|
|
|
|
|
|
|
|
|
|
# UI layout: two-column row (image + button on the left, model picker + caption
# on the right), with the Examples widget and event wiring underneath.
with gr.Blocks(theme=steel_blue_theme) as demo:
    gr.Markdown("# **Florence-2 Models Image Captions**")
    gr.Markdown(description)
    with gr.Row():

        with gr.Column(scale=2):
            image_input = gr.Image(label="Upload Image", type="pil", height=400)
            generate_btn = gr.Button("Generate Caption", variant="primary")

        with gr.Column(scale=3):
            model_choice = gr.Radio(["Florence-2-base", "Florence-2-large"], label="Model Choice", value="Florence-2-base")
            output = gr.Textbox(label="Generated Caption", lines=10, show_copy_button=True)

    # Fix: each example row carries TWO values (image path, model choice), so
    # both components must be listed as inputs — the original passed only
    # [image_input], which gr.Examples rejects at startup because the row
    # length does not match len(inputs). Creating the widget after both
    # columns ensures model_choice exists when referenced here.
    gr.Examples(examples=examples, inputs=[image_input, model_choice])

    generate_btn.click(fn=describe_image, inputs=[image_input, model_choice], outputs=output)
|
|
|
|
|
|
|
|
# debug=True surfaces tracebacks in the console; show_error=True also shows
# them in the UI. NOTE(review): mcp_server=True presumably exposes the app's
# functions over MCP and ssr_mode=False disables server-side rendering —
# confirm both against the installed gradio version's launch() docs.
demo.launch(debug=True, mcp_server=True, ssr_mode=False, show_error=True)