# Florence-2 image-captioning Gradio app.
# (Removed non-Python extraction residue: file-size line, git-blame hashes,
#  and a row of display line numbers that would be syntax errors here.)
import os
import subprocess
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import gradio as gr
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from typing import Iterable
# --- Theme and CSS Definition ---
# Register a custom "steel_blue" palette on Gradio's color registry so the
# theme class below can reference it like a built-in hue.
_STEEL_BLUE_SHADES = {
    "c50": "#EBF3F8",
    "c100": "#D3E5F0",
    "c200": "#A8CCE1",
    "c300": "#7DB3D2",
    "c400": "#529AC3",
    "c500": "#4682B4",  # SteelBlue base color
    "c600": "#3E72A0",
    "c700": "#36638C",
    "c800": "#2E5378",
    "c900": "#264364",
    "c950": "#1E3450",
}
colors.steel_blue = colors.Color(name="steel_blue", **_STEEL_BLUE_SHADES)
class SteelBlueTheme(Soft):
    """Soft-derived Gradio theme built around the custom steel_blue palette.

    Defaults pick gray as the primary hue, steel_blue as the secondary
    (buttons/sliders), and slate neutrals, with Outfit / IBM Plex Mono fonts.
    """

    # Component-level overrides applied on top of the Soft base theme.
    _OVERRIDES = dict(
        body_background_fill="linear-gradient(135deg, *primary_100, *primary_200)",
        body_background_fill_dark="linear-gradient(135deg, *primary_800, *primary_900)",
        button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
        button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
        button_primary_text_color="white",
        slider_color="*secondary_500",
        slider_color_dark="*secondary_600",
        block_title_text_weight="600",
        block_border_width="2px",
        block_shadow="*shadow_drop_lg",
        button_shadow="*shadow_drop_lg",
        button_large_padding="12px",
    )

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(**self._OVERRIDES)
# Instantiate the theme
steel_blue_theme = SteelBlueTheme()

# --- Model and App Setup ---

def _install_flash_attn():
    """Best-effort install of flash-attn; the app keeps working without it."""
    try:
        # BUGFIX: pass the *full* current environment plus the skip flag.
        # The original env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"} replaced
        # the entire environment, dropping PATH so the shell may not find pip.
        subprocess.run(
            'pip install flash-attn --no-build-isolation',
            env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
            check=True,
            shell=True,  # NOTE(review): shell string with fixed args; no untrusted input.
        )
    except subprocess.CalledProcessError as e:
        print(f"Error installing flash-attn: {e}")
        print("Continuing without flash-attn.")

_install_flash_attn()

# Determine the device to use
device = "cuda" if torch.cuda.is_available() else "cpu"

def _load_florence(repo_id):
    """Load a Florence-2 checkpoint and its processor.

    Returns (model, processor), or (None, None) if loading fails so the UI
    can report the failure instead of crashing at import time.
    """
    try:
        model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        return model, processor
    except Exception as e:
        print(f"Error loading model {repo_id}: {e}")
        return None, None

# Load both checkpoints; either may be None if its download/init failed.
vision_language_model_base, vision_language_processor_base = _load_florence('microsoft/Florence-2-base')
vision_language_model_large, vision_language_processor_large = _load_florence('microsoft/Florence-2-large')
def describe_image(uploaded_image, model_choice):
    """Generate a detailed caption for the input image with the chosen model.

    Args:
        uploaded_image: PIL image or numpy array from the Gradio widget
            (may be None when nothing was uploaded).
        model_choice: "Florence-2-base" or "Florence-2-large".

    Returns:
        The generated caption string, or a human-readable error message.
    """
    if uploaded_image is None:
        return "Please upload an image."
    # Resolve the model pair first; invalid choices never touch the globals.
    if model_choice == "Florence-2-base":
        model, processor = vision_language_model_base, vision_language_processor_base
    elif model_choice == "Florence-2-large":
        model, processor = vision_language_model_large, vision_language_processor_large
    else:
        return "Invalid model choice."
    if model is None:
        # Keep the original per-checkpoint failure messages byte-for-byte.
        if model_choice == "Florence-2-base":
            return "Base model failed to load."
        return "Large model failed to load."
    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.fromarray(uploaded_image)
    # ROBUSTNESS: normalize RGBA/palette/grayscale uploads to 3-channel RGB
    # before handing off to the processor — presumably what Florence-2
    # expects; no-op for images already in RGB mode.
    if uploaded_image.mode != "RGB":
        uploaded_image = uploaded_image.convert("RGB")
    inputs = processor(text="<MORE_DETAILED_CAPTION>", images=uploaded_image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,  # beam search for a more detailed, deterministic caption
        )
    # Keep special tokens: Florence-2's post-processor parses the raw output.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    processed_description = processor.post_process_generation(
        generated_text,
        task="<MORE_DETAILED_CAPTION>",
        image_size=(uploaded_image.width, uploaded_image.height)
    )
    image_description = processed_description["<MORE_DETAILED_CAPTION>"]
    print("\nImage description generated!:", image_description)
    return image_description
# Blurb rendered under the app title; appended warning when running on CPU.
_BASE_BLURB = (
    "> Select the model to use for generating the image description. "
    "'Base' is smaller and faster, while 'Large' is more accurate but slower."
)
_CPU_NOTE = " Note: Running on CPU, which may be slow for large models."
description = _BASE_BLURB + (_CPU_NOTE if device == "cpu" else "")

# Example gallery: (image path, model choice) pairs preloaded into the UI.
_EXAMPLE_PAIRS = (
    ("images/2.jpeg", "Florence-2-large"),
    ("images/1.jpeg", "Florence-2-base"),
    ("images/3.jpeg", "Florence-2-large"),
    ("images/4.jpeg", "Florence-2-large"),
)
examples = [[path, model] for path, model in _EXAMPLE_PAIRS]
# Create the Gradio interface with Blocks
with gr.Blocks(theme=steel_blue_theme) as demo:
    gr.Markdown("# **Florence-2 Models Image Captions**")
    gr.Markdown(description)
    with gr.Row():
        # Left column: image input, trigger button, and example gallery.
        with gr.Column(scale=2):
            image_input = gr.Image(label="Upload Image", type="pil", height=400)
            generate_btn = gr.Button("Generate Caption", variant="primary")
            # NOTE(review): examples carry two columns but only the image is
            # wired as an input — the model column is ignored; confirm intent.
            gr.Examples(examples=examples, inputs=[image_input])
        # Right column: model selector and caption output.
        with gr.Column(scale=3):
            model_choice = gr.Radio(["Florence-2-base", "Florence-2-large"], label="Model Choice", value="Florence-2-base")
            output = gr.Textbox(label="Generated Caption", lines=10, show_copy_button=True)
    # Wire the button: (image, model choice) -> caption text.
    generate_btn.click(fn=describe_image, inputs=[image_input, model_choice], outputs=output)

# Launch the interface. BUGFIX: removed the stray trailing "|" artifact that
# followed this call in the original file and made it a syntax error.
demo.launch(debug=True, mcp_server=True, ssr_mode=False, show_error=True)