File size: 3,585 Bytes
019b165
 
3575d8b
 
019b165
 
3575d8b
019b165
3575d8b
019b165
3575d8b
 
 
019b165
3575d8b
 
 
 
019b165
3575d8b
 
 
019b165
 
 
3575d8b
019b165
 
 
 
 
 
 
 
 
3575d8b
 
 
 
 
 
 
 
 
 
 
 
 
 
019b165
3575d8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
019b165
3575d8b
 
019b165
3575d8b
 
 
 
019b165
 
 
 
 
 
3575d8b
 
 
 
 
019b165
3575d8b
 
 
 
019b165
 
 
3575d8b
 
 
019b165
 
 
 
 
 
 
 
 
 
 
a6a4f0e
 
3575d8b
 
a6a4f0e
019b165
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Configuration
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

print(f"Loading {MODEL_ID}...")

# 1. Load Model
# We use bfloat16 (half precision) which is faster than 4-bit for small models
# and fits easily in 16GB or even 8GB VRAM.
try:
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # lets accelerate place the model on the available GPU(s)
    )

    # min_pixels / max_pixels bound the resolution the processor feeds the
    # vision tower — lower ceilings trade image detail for speed.
    processor = AutoProcessor.from_pretrained(
        MODEL_ID, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
    )
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure you have a GPU available.")
    # Fix: the original called `exit()`, which (a) lives in the `site` module
    # and may not exist under `python -S` or in frozen apps, and (b) exits
    # with status 0, hiding the failure from shells/CI. Raise SystemExit with
    # a non-zero code and chain the original cause instead.
    raise SystemExit(1) from e

def chat_response(message, history, image_input):
    """Generate a model reply grounded in the uploaded image.

    Called by Gradio's ChatInterface. `history` is accepted for interface
    compatibility but the prompt is built from the current turn only.
    """
    # The model needs an image to talk about; bail out early without one.
    if image_input is None:
        return "Please upload an image first to chat about it!"

    # Assemble a single-turn conversation in Qwen2-VL's message format:
    # each content entry is typed ("image" / "text").
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_input},  # PIL image accepted directly
                {"type": "text", "text": message},
            ],
        }
    ]

    # Render the textual chat template and pull out the vision inputs
    # (qwen_vl_utils converts them into model-ready form).
    prompt_text = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    vision_images, vision_videos = process_vision_info(conversation)

    # Tokenize text + pack vision tensors, then keep everything on the
    # model's device.
    model_inputs = processor(
        text=[prompt_text],
        images=vision_images,
        videos=vision_videos,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    # Sampled decoding, capped at 200 new tokens for latency.
    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    # Drop the prompt tokens from each sequence so only the newly generated
    # continuation gets decoded.
    new_token_ids = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(model_inputs.input_ids, output_ids)
    ]
    return processor.batch_decode(
        new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

# --- Gradio UI Setup ---
# Two-column layout: image upload on the left (1/3 width), chat on the right
# (2/3 width). The image component is wired into the chat callback via
# ChatInterface's additional_inputs.
with gr.Blocks(title="Qwen2-VL Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Qwen2-VL-2B: Fast Image Chat")
    gr.Markdown("Upload an image and ask questions. This 2B model is significantly faster than LLaVA-7B.")
    
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" so chat_response receives a PIL.Image it can pass
            # straight into the Qwen2-VL message format.
            image_box = gr.Image(type="pil", label="Upload Image")
        
        with gr.Column(scale=2):
            chatbot = gr.ChatInterface(
                fn=chat_response,
                # image_box becomes the third positional argument of
                # chat_response (after message and history).
                additional_inputs=[image_box],
                title="Chat",
                description="Ask about the uploaded image.",
                # Example rows are [message, image]; image is None, so the
                # user must still upload one before the examples work.
                examples=[
                    ["What is in this image?", None],
                    ["Describe the lighting.", None],
                    ["Read the text in the image.", None],
                ],
            )

# Script entry point: enable request queuing (generation can take seconds,
# so concurrent visitors are serialized) and start the web server.
if __name__ == "__main__":
    app = demo.queue()
    app.launch()