# Author: Darius Morawiec
# Last commit: "Refactor model loading and processing" (9afc0f5)
# (Hugging Face file-viewer residue: raw | history | blame | 12.6 kB)
import base64
import gc
import json
import os
from io import BytesIO
from pathlib import Path
import gradio as gr
import torch
from json_repair import repair_json
from qwen_vl_utils import process_vision_info
from transformers import (
AutoProcessor,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
Qwen3VLForConditionalGeneration,
)
# Pin the process to the first GPU. Note: this overwrites any value already
# present in the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:

    class spaces:
        """Local stand-in for the Hugging Face ``spaces`` package.

        Used when ``SPACES_ZERO_GPU`` is not set. Only ``spaces.GPU`` is
        provided, and it leaves the decorated function unchanged.
        """

        @staticmethod
        def GPU(func=None, duration: int = 60):
            """No-op replacement for ``spaces.GPU``.

            Accepts both the bare form ``spaces.GPU(fn)`` and the factory form
            ``spaces.GPU(duration=...)(fn)``. The factory form is how ``run``
            decorates its generation step. The previous shim required ``func``
            positionally, so the factory form raised ``TypeError``.

            Args:
                func: Function to decorate, or ``None`` when called as a factory.
                duration: Accepted for signature compatibility; ignored here.

            Returns:
                ``func`` itself, or (factory form) an identity decorator.
            """

            def decorator(fn):
                return fn

            return decorator if func is None else decorator(func)
# Folder of bundled example images used by the gr.Examples gallery.
EXAMPLES_DIR = Path(__file__).parent.joinpath("examples")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Checkpoints offered in the model dropdown; each model card lives at
# https://huggingface.co/<model id>.
model_ids = [
    # Qwen2-VL
    "Qwen/Qwen2-VL-2B-Instruct",
    "Qwen/Qwen2-VL-7B-Instruct",
    # Qwen2.5-VL
    "Qwen/Qwen2.5-VL-3B-Instruct",
    "Qwen/Qwen2.5-VL-7B-Instruct",
    "Qwen/Qwen2.5-VL-32B-Instruct",
    "Qwen/Qwen2.5-VL-72B-Instruct",
    # Qwen3-VL
    "Qwen/Qwen3-VL-2B-Instruct",
    "Qwen/Qwen3-VL-4B-Instruct",
    "Qwen/Qwen3-VL-8B-Instruct",
    "Qwen/Qwen3-VL-32B-Instruct",
]
def scale_image(image, target_size=1000):
    """Downscale *image* so that its longer side is at most *target_size* pixels.

    The aspect ratio is preserved; the shorter side is truncated to an int and
    clamped to at least 1 pixel so extremely thin images still produce a valid
    size. Images that already fit are returned unchanged (the same object).

    Args:
        image: Image object exposing ``.size -> (width, height)`` and
            ``.resize((width, height))`` (a PIL image in this app).
        target_size: Maximum length of the longer side. Non-integer values
            (e.g. floats coming from a UI slider) are truncated to an int.

    Returns:
        *image* itself, or the result of ``image.resize(...)``.
    """
    width, height = image.size
    # resize() needs integer sizes; also keep the target at least 1 pixel.
    target = max(1, int(target_size))
    if max(width, height) <= target:
        return image
    if width >= height:
        new_size = (target, max(1, int((target / width) * height)))
    else:
        new_size = (max(1, int((target / height) * width)), target)
    return image.resize(new_size)
def image_to_base64(image):
    """Serialise *image* as PNG and return those bytes base64-encoded as text.

    The result is embedded in a ``data:image;base64,...`` URI for the model.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
# Main UI. Gradio lays components out in the order they are created inside
# these context managers, so statement order here defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# Qwen-VL Object-Detection")
    gr.Markdown(
        "Compare [Qwen3-VL](https://huggingface.co/collections/Qwen/qwen3-vl), [Qwen2.5-VL](https://huggingface.co/collections/Qwen/qwen25-vl) and [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl) models by [Qwen](https://huggingface.co/Qwen) for object detection."
    )
    with gr.Row():
        # Left column: image input plus all generation settings.
        with gr.Column():
            gr.Markdown("## Inputs")
            # type="pil" -> the handler receives a PIL image (or None when empty).
            image_input = gr.Image(
                label="Input Image",
                type="pil",
            )
            gr.Markdown("## Settings")
            # No `value=` is given, so the initial selection depends on the
            # installed Gradio version's Dropdown default — TODO confirm.
            input_model_id = gr.Dropdown(
                choices=model_ids,
                label="✨ Select Model ID",
            )
            # Asks the model to answer with a JSON list of
            # {"bbox_2d": [xmin, ymin, xmax, ymax], "label": ...} objects.
            # Also reused by the example rows defined later in the file.
            default_system_prompt = 'You are a helpful assistant to detect objects in images. When asked to detect elements based on a description, you return a valid JSON object containing bounding boxes for all elements in the form `[{"bbox_2d": [xmin, ymin, xmax, ymax], "label": "placeholder"}, ...]`. For example, a valid response could be: `[{"bbox_2d": [10, 30, 20, 60], "label": "placeholder"}, {"bbox_2d": [40, 15, 52, 27], "label": "placeholder"}]`.'
            system_prompt = gr.Textbox(
                label="System Prompt",
                lines=3,
                value=default_system_prompt,
            )
            default_user_prompt = "detect object"
            user_prompt = gr.Textbox(
                label="User Prompt",
                lines=3,
                value=default_user_prompt,
            )
            # Generation budget for the model's answer (32..4096, step 32).
            max_new_tokens = gr.Slider(
                label="Max New Tokens",
                minimum=32,
                maximum=4096,
                value=256,
                step=32,
                interactive=True,
            )
            # Longer-side pixel limit applied to the image before inference.
            image_target_size = gr.Slider(
                label="Image Target Size",
                minimum=256,
                maximum=4096,
                value=1024,
                step=1,
                interactive=True,
            )
        # Right column: annotated image and the raw JSON detections.
        with gr.Column():
            gr.Markdown("## Outputs")
            output_annotated_image = gr.AnnotatedImage(
                format="jpeg",
                key="output_annotated_image",
                label="Output Image",
            )
            gr.Markdown("## Detections")
            output_text = gr.Textbox(
                label="Output Text",
                lines=10,
                key="output_text",
            )
    with gr.Row():
        run_button = gr.Button("Run")
# Module-level cache of the single model/processor pair currently in memory;
# `load_model` replaces it whenever a different checkpoint is requested.
current_model = None
current_processor = None
current_model_id = None


def load_model(model_id: str):
    """Return ``(model, processor)`` for *model_id*, loading it on first use.

    Only one checkpoint is kept in memory. Requesting a different ID first
    drops the previous model and processor. It then runs ``gc.collect()`` and,
    when CUDA is available, empties the CUDA cache before loading the new
    checkpoint with ``torch_dtype="auto"`` and ``device_map="auto"``.

    Args:
        model_id: Hub ID starting with ``Qwen/Qwen2-VL``, ``Qwen/Qwen2.5-VL``
            or ``Qwen/Qwen3-VL``.

    Returns:
        Tuple of the model (switched to eval mode) and its ``AutoProcessor``.

    Raises:
        ValueError: if *model_id* is empty/None or not from a supported model
            family. The currently cached model is left untouched in that case.
    """
    global current_model, current_processor, current_model_id

    # Cache hit: the requested checkpoint is already loaded.
    if model_id == current_model_id and current_model is not None:
        return current_model, current_processor

    # Resolve the loader class BEFORE evicting the cache, so an invalid ID
    # cannot throw away a working model. (Raised explicitly instead of via
    # `assert`, which is stripped under `python -O`.)
    if not model_id:
        raise ValueError("No model ID given.")
    if model_id.startswith("Qwen/Qwen2-VL"):
        model_loader = Qwen2VLForConditionalGeneration
    elif model_id.startswith("Qwen/Qwen2.5-VL"):
        model_loader = Qwen2_5_VLForConditionalGeneration
    elif model_id.startswith("Qwen/Qwen3-VL"):
        model_loader = Qwen3VLForConditionalGeneration
    else:
        raise ValueError(f"Unsupported model ID: {model_id}")

    # Drop every reference to the previous checkpoint, then reclaim memory.
    current_model = None
    current_processor = None
    current_model_id = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    current_model = model_loader.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto",
    ).eval()
    current_processor = AutoProcessor.from_pretrained(model_id)
    current_model_id = model_id
    return current_model, current_processor
def _as_detection_list(parsed):
    """Return the detection dicts contained in the model's parsed JSON answer.

    The system prompt asks for a JSON list of ``{"bbox_2d": [...], "label": ...}``
    objects. A model may instead answer with a single bare object, or with
    something that is not a list at all. Normalise those cases so the caller
    always iterates over dicts.
    """
    if isinstance(parsed, dict):
        return [parsed]
    if isinstance(parsed, list):
        return [item for item in parsed if isinstance(item, dict)]
    return []


def run(
    image,
    model_id: str,
    system_prompt: str,
    user_prompt: str,
    max_new_tokens: int = 1024,
    image_target_size: int | None = None,
):
    """Detect objects in *image* with the selected Qwen-VL checkpoint.

    Args:
        image: Input PIL image (the UI's ``gr.Image`` uses ``type="pil"``).
        model_id: Hub ID of a Qwen2-VL, Qwen2.5-VL or Qwen3-VL checkpoint;
            loaded (or reused from cache) via ``load_model``.
        system_prompt: Instruction text sent before the user prompt.
        user_prompt: What to detect, e.g. ``"detect person, bicycle"``.
        max_new_tokens: Generation budget for the model's answer.
        image_target_size: If truthy, the image is downscaled so its longer
            side is at most this many pixels before being sent to the model.

    Returns:
        ``[(image, [(bbox, label), ...]), json_text]``. Each ``bbox`` is
        ``[xmin, ymin, xmax, ymax]`` in pixels of the ORIGINAL *image*. The
        two items feed the ``gr.AnnotatedImage`` and the output ``gr.Textbox``.

    Raises:
        gr.Error: if no image or no model was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if not model_id:
        raise gr.Error("Please select a model ID.")

    model, processor = load_model(model_id)

    # Gradio sliders can deliver floats; resizing and generation need ints.
    max_new_tokens = int(max_new_tokens)
    model_image = (
        scale_image(image, int(image_target_size)) if image_target_size else image
    )
    base64_image = image_to_base64(model_image)

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"data:image;base64,{base64_image}",
                },
                {"type": "text", "text": system_prompt},
                {"type": "text", "text": user_prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(DEVICE)

    # Only generation is wrapped in spaces.GPU (a no-op shim outside ZeroGPU).
    @spaces.GPU(duration=300)
    def _generate(**kwargs):
        return model.generate(**kwargs)

    generated_ids = _generate(**inputs, max_new_tokens=max_new_tokens)
    # Strip the prompt tokens so only the newly generated answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    raw_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    # Repair possibly malformed JSON from the model before parsing it.
    output_json = json.loads(repair_json(str(raw_text)))

    if model_id.startswith("Qwen/Qwen2.5-VL"):
        # Qwen2.5-VL answers in absolute pixels of the image the processor
        # actually fed to the model, not the original upload and not
        # `model_image`. That image is the (possibly downscaled) upload, further
        # resized to whole patches. Its size is grid * patch_size. Map back
        # from there to the original image.
        _, grid_h, grid_w = inputs["image_grid_thw"][0].tolist()
        patch_size = getattr(processor.image_processor, "patch_size", 14)
        x_scale = image.width / (grid_w * patch_size)
        y_scale = image.height / (grid_h * patch_size)
    else:
        # Qwen2-VL / Qwen3-VL coordinates are treated as normalised to 0..1000.
        x_scale = image.width / 1000
        y_scale = image.height / 1000

    bboxes = []
    for detection in _as_detection_list(output_json):
        if "label" not in detection:
            continue
        coords = detection.get("bbox_2d")
        if not isinstance(coords, (list, tuple)) or len(coords) != 4:
            continue
        try:
            xmin, ymin, xmax, ymax = (float(value) for value in coords)
        except (TypeError, ValueError):
            # Skip boxes whose coordinates are not numbers.
            continue
        bbox = [
            int(xmin * x_scale),
            int(ymin * y_scale),
            int(xmax * x_scale),
            int(ymax * y_scale),
        ]
        # AnnotatedImage labels are strings.
        bboxes.append((bbox, str(detection["label"])))

    return [(image, bboxes), json.dumps(output_json)]
# Components passed to `run` (in its parameter order) and the components it
# fills. Shared by the example gallery and the Run button so the two cannot
# drift apart.
run_inputs = [
    image_input,
    input_model_id,
    system_prompt,
    user_prompt,
    max_new_tokens,
    image_target_size,
]
run_outputs = [
    output_annotated_image,
    output_text,
]

# Re-enter the `demo` Blocks context so the rows and the click listener below
# are attached to the app even though they follow the handler definitions.
with demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Examples")
            # Each row: image path, model ID, system prompt, user prompt,
            # max new tokens, image target size (same order as run_inputs).
            # cache_examples=True with cache_mode="eager" precomputes outputs
            # by calling `run` on every row.
            gr.Examples(
                fn=run,
                cache_examples=True,
                cache_mode="eager",
                run_on_click=True,
                examples=[
                    [
                        EXAMPLES_DIR / "niklas-ohlrogge-niamoh-de-fDYRfHoRC4k-unsplash.jpg",
                        "Qwen/Qwen3-VL-4B-Instruct",
                        default_system_prompt,
                        "detect sailboat, rowboat, person",
                        512,
                        1920,
                    ],
                    [
                        EXAMPLES_DIR / "elevate-nYgy58eb9aw-unsplash.jpg",
                        "Qwen/Qwen3-VL-4B-Instruct",
                        default_system_prompt,
                        "detect shirt, jeans, jacket, skirt, sunglasses, earring, drink",
                        1024,
                        1920,
                    ],
                    [
                        EXAMPLES_DIR / "markus-spiske-oPDQGXW7i40-unsplash.jpg",
                        "Qwen/Qwen3-VL-4B-Instruct",
                        default_system_prompt,
                        "detect basketball, player with white jersey, player with black jersey",
                        512,
                        1920,
                    ],
                    [
                        EXAMPLES_DIR / "william-hook-9e9PD9blAto-unsplash.jpg",
                        "Qwen/Qwen3-VL-4B-Instruct",
                        default_system_prompt,
                        "detect app to find great places, app to take beautiful photos, app to listen music",
                        512,
                        1920,
                    ],
                    [
                        EXAMPLES_DIR / "tasso-mitsarakis-dw7Y4W6Rhmk-unsplash.jpg",
                        "Qwen/Qwen3-VL-4B-Instruct",
                        default_system_prompt,
                        "detect person, bicycle, netherlands flag",
                        1920,
                        1920,
                    ],
                ],
                inputs=run_inputs,
                outputs=run_outputs,
            )
            # Shown only when no CUDA device was detected at startup.
            if DEVICE != "cuda":
                gr.Markdown(
                    "👉 It's recommended to run this application on a machine with a CUDA-compatible GPU for optimal performance. You can clone this space locally or duplicate this space with a CUDA-enabled runtime."
                )

    # Connect the button to the detection function
    run_button.click(
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
    )
if __name__ == "__main__":
    # Launch with default options. This spot previously kept `share=True` as
    # a commented-out keyword argument.
    demo.launch()