import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
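
# Optional pre-flight check (a sketch, not part of the original app; the app
# assumes a CUDA GPU, and device_map="auto" would otherwise fall back to CPU,
# which is much slower):
# if not torch.cuda.is_available():
#     print("Warning: no CUDA GPU detected; generation will be slow on CPU.")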
# Configuration
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

print(f"Loading {MODEL_ID}...")

# 1. Load the model.
# We load in bfloat16 (16-bit), which is faster than 4-bit quantization for
# small models and fits comfortably in 16 GB or even 8 GB of VRAM.
try:
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # min_pixels and max_pixels bound the image-resolution budget for speed:
    # each visual token covers a 28x28 pixel patch, so 256*28*28 (~0.2 MP)
    # and 1280*28*28 (~1.0 MP) correspond to 256 and 1280 visual tokens.
    processor = AutoProcessor.from_pretrained(
        MODEL_ID, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
    )
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure you have a GPU available.")
    raise SystemExit(1)
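
# If VRAM is tight, 4-bit quantization is a possible alternative to bfloat16.
# A minimal sketch (assumes the optional bitsandbytes package is installed;
# not used by this app):
#
# from transformers import BitsAndBytesConfig
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     MODEL_ID,
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     device_map="auto",
# )
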
def chat_response(message, history, image_input):
    """
    Main generation function called by Gradio's ChatInterface.

    `history` is supplied by ChatInterface but unused here: each turn is
    answered from the uploaded image and the latest message alone.
    """
    if image_input is None:
        return "Please upload an image first to chat about it!"

    # 2. Prepare the messages for Qwen2-VL.
    # Qwen expects a specific format: a list of messages whose "content" is a
    # list of typed parts ("image", "text").
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_input,  # Pass the PIL image directly
                },
                {"type": "text", "text": message},
            ],
        }
    ]
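
    # Note: per the qwen_vl_utils docs, the "image" field also accepts a file
    # path, URL, or base64 string instead of a PIL image, e.g.:
    #   {"type": "image", "image": "https://example.com/demo.jpeg"}
    # (the URL here is a placeholder; this app passes the PIL image directly).
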
    # 3. Process inputs.
    # qwen_vl_utils turns the message list into image/video tensors.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move inputs to the same device as the model
    inputs = inputs.to(model.device)

    # 4. Generate the response.
    # We limit max_new_tokens to 200 for speed.
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    # 5. Decode the output.
    # We trim the input tokens from the output to keep only the new response.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    response = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response
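
# Quick headless sanity check (a sketch, not wired into the app; "test.jpg"
# is a hypothetical local file, substitute any image path):
#
# from PIL import Image
# print(chat_response("What is in this image?", [], Image.open("test.jpg")))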

# --- Gradio UI Setup ---
with gr.Blocks(title="Qwen2-VL Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Qwen2-VL-2B: Fast Image Chat")
    gr.Markdown(
        "Upload an image and ask questions. "
        "This 2B model is significantly faster than LLaVA-7B."
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_box = gr.Image(type="pil", label="Upload Image")
        with gr.Column(scale=2):
            chatbot = gr.ChatInterface(
                fn=chat_response,
                additional_inputs=[image_box],
                title="Chat",
                description="Ask about the uploaded image.",
                # With additional_inputs, each example row is
                # [message, image]; None means no example image.
                examples=[
                    ["What is in this image?", None],
                    ["Describe the lighting.", None],
                    ["Read the text in the image.", None],
                ],
            )

if __name__ == "__main__":
    demo.queue().launch()
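
# Optional: streaming output. A hedged sketch (not part of the original app)
# of how chat_response could yield tokens incrementally using transformers'
# TextIteratorStreamer; gr.ChatInterface accepts generator functions:
#
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# def chat_response_streaming(message, history, image_input):
#     ...  # build `inputs` exactly as in chat_response above
#     streamer = TextIteratorStreamer(
#         processor.tokenizer, skip_prompt=True, skip_special_tokens=True
#     )
#     Thread(
#         target=model.generate,
#         kwargs=dict(**inputs, streamer=streamer, max_new_tokens=200),
#     ).start()
#     partial = ""
#     for chunk in streamer:
#         partial += chunk
#         yield partial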