import base64
import gc
import json
import os
from io import BytesIO

# Pin the app to the first GPU. This must happen before CUDA is initialized,
# so it is set ahead of the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import gradio as gr
import torch
from json_repair import repair_json
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
    Qwen3VLForConditionalGeneration,
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
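
# Checkpoints selectable in the UI; each entry links to its model card.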
model_ids = [
    "Qwen/Qwen2-VL-2B-Instruct",  # https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct
    "Qwen/Qwen2-VL-7B-Instruct",  # https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
    "Qwen/Qwen2.5-VL-3B-Instruct",  # https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct
    "Qwen/Qwen2.5-VL-7B-Instruct",  # https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct
    "Qwen/Qwen2.5-VL-32B-Instruct",  # https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct
    "Qwen/Qwen2.5-VL-72B-Instruct",  # https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct
    "Qwen/Qwen3-VL-2B-Instruct",  # https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct
    "Qwen/Qwen3-VL-4B-Instruct",  # https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct
    "Qwen/Qwen3-VL-8B-Instruct",  # https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct
    "Qwen/Qwen3-VL-32B-Instruct",  # https://huggingface.co/Qwen/Qwen3-VL-32B-Instruct
]
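
# Note: the 32B and 72B checkpoints need tens of GB of VRAM; device_map="auto"
# below lets Accelerate shard or offload them across available devices, but
# expect long load times on a single GPU.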

def image_to_base64(image):
    """Encode a PIL image as a base64 PNG string for use in a data URI."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
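
# Usage sketch: this helper is used below to inline the uploaded image into
# the chat message as a data URI, e.g.
#   {"type": "image", "image": f"data:image/png;base64,{image_to_base64(img)}"}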

with gr.Blocks() as demo:
    gr.Markdown("# Qwen-VL Object Detection")
    gr.Markdown(
        "Compare [Qwen3-VL](https://huggingface.co/collections/Qwen/qwen3-vl), [Qwen2.5-VL](https://huggingface.co/collections/Qwen/qwen25-vl) and [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl) models by [Qwen](https://huggingface.co/Qwen) for object detection in images."
    )

    if DEVICE != "cuda":
        gr.Markdown(
            "👉 Running this application on a machine with a CUDA-compatible GPU is recommended for reasonable inference times. You can clone this Space locally or duplicate it with a CUDA-enabled runtime."
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Inputs")
            image_input = gr.Image(
                label="Input Image",
                type="pil",
            )

            gr.Markdown("## Settings")
            input_model_id = gr.Dropdown(
                choices=model_ids,
                label="Select Model ID",
            )
            default_system_prompt = 'You are a helpful assistant to detect objects in images. When asked to detect elements based on a description, you return a valid JSON object containing bounding boxes for all elements in the form `[{"bbox_2d": [xmin, ymin, xmax, ymax], "label": "placeholder"}, ...]`. For example, a valid response could be: `[{"bbox_2d": [10, 30, 20, 60], "label": "placeholder"}, {"bbox_2d": [40, 15, 52, 27], "label": "placeholder"}]`.'
            system_prompt = gr.Textbox(
                label="System Prompt:",
                lines=3,
                value=default_system_prompt,
            )
            default_user_prompt = "detect object"
            user_prompt = gr.Textbox(
                label="User Prompt:",
                lines=3,
                value=default_user_prompt,
            )
            max_new_tokens = gr.Slider(
                label="Max New Tokens:",
                minimum=32,
                maximum=4096,
                value=256,
                step=32,
                interactive=True,
            )
        with gr.Column():
            gr.Markdown("## Outputs")
            output_annotated_image = gr.AnnotatedImage(
                format="jpeg",
                key="output_annotated_image",
                label="Output Image",
            )
            gr.Markdown("## Detections")
            output_text = gr.Textbox(
                label="Output Text",
                lines=3,
                key="output_text",
            )

    with gr.Row():
        run_button = gr.Button("Run")

    # Global variables to track the loaded model
    current_model = None
    current_processor = None
    current_model_id = None
    def run(
        image,
        system_prompt: str,
        user_prompt: str,
        model_id: str,
        max_new_tokens: int = 1024,
    ):
        """Run object detection with the selected Qwen-VL model and return
        an annotated image plus the raw model output."""
        global current_model, current_processor, current_model_id

        # Qwen2.5-VL returns absolute pixel coordinates; the other model
        # families are treated as returning coordinates on a 0-1000 grid
        # that is rescaled to the image size below.
        scale = not model_id.startswith("Qwen/Qwen2.5-VL")

        # Only load a model if it differs from the currently loaded one
        if current_model_id != model_id or current_model is None:
            # Release the previous model and processor from memory
            if current_model is not None:
                del current_model
                current_model = None
            if current_processor is not None:
                del current_processor
                current_processor = None

            # Force garbage collection and clear the CUDA cache
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            # Pick the model class matching the selected checkpoint family
            if model_id.startswith("Qwen/Qwen2-VL"):
                model_loader = Qwen2VLForConditionalGeneration
            elif model_id.startswith("Qwen/Qwen2.5-VL"):
                model_loader = Qwen2_5_VLForConditionalGeneration
            elif model_id.startswith("Qwen/Qwen3-VL"):
                model_loader = Qwen3VLForConditionalGeneration
            else:
                raise gr.Error(f"Unsupported model ID: {model_id}")

            current_model = model_loader.from_pretrained(
                model_id,
                torch_dtype="auto",
                device_map="auto",
            ).eval()
            current_processor = AutoProcessor.from_pretrained(model_id)
            current_model_id = model_id

        model = current_model
        processor = current_processor
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": f"data:image/png;base64,{image_to_base64(image)}",
                    },
                    {"type": "text", "text": system_prompt},
                    {"type": "text", "text": user_prompt},
                ],
            }
        ]

        # Render the chat template and extract the vision inputs
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(DEVICE)
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Strip the prompt tokens so only newly generated tokens are decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # The model may emit truncated or slightly malformed JSON, so make a
        # best effort to repair it before parsing
        output_text = repair_json(output_text)
        output_json = json.loads(output_text)
        if not isinstance(output_json, list):
            output_json = []

        # Rescale normalized 0-1000 coordinates to pixel space where needed
        x_scale = image.width / 1000 if scale else 1.0
        y_scale = image.height / 1000 if scale else 1.0
        bboxes = []
        for detection in output_json:
            # Skip entries that are not well-formed detections
            if not isinstance(detection, dict):
                continue
            if "bbox_2d" not in detection or "label" not in detection:
                continue
            if len(detection["bbox_2d"]) != 4:
                continue
            xmin, ymin, xmax, ymax = detection["bbox_2d"]
            label = detection["label"]
            bbox = [
                int(xmin * x_scale),
                int(ymin * y_scale),
                int(xmax * x_scale),
                int(ymax * y_scale),
            ]
            bboxes.append((bbox, label))

        return [(image, bboxes), str(output_text)]
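
    # Example (hypothetical values): calling the pipeline directly, e.g. from
    # a test, bypassing the UI:
    #   annotated, raw = run(img, default_system_prompt, "detect all cats",
    #                        "Qwen/Qwen2-VL-2B-Instruct", max_new_tokens=256)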
    # Connect the button to the detection function
    run_button.click(
        fn=run,
        inputs=[
            image_input,
            system_prompt,
            user_prompt,
            input_model_id,
            max_new_tokens,
        ],
        outputs=[
            output_annotated_image,
            output_text,
        ],
    )

if __name__ == "__main__":
    demo.launch(
        # share=True,
    )
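
# To run locally (assumed, unpinned dependencies):
#   pip install gradio torch transformers accelerate qwen-vl-utils json-repair
#   python app.py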