File size: 3,384 Bytes
f7c4f4f
 
1460a51
 
d4489c6
 
fd75203
f7c4f4f
2ee66d7
d6321ed
7db7aa3
d4489c6
1460a51
2ee66d7
1460a51
0a48e59
fd75203
d4489c6
fd75203
1460a51
 
fd75203
1460a51
d4489c6
de2489e
 
fd75203
 
de2489e
a8f883c
 
 
1460a51
 
 
 
 
 
fd75203
d4489c6
1460a51
 
 
 
 
fd75203
1460a51
 
 
fd75203
0a48e59
1460a51
0a48e59
fd75203
1460a51
 
0a48e59
 
 
1460a51
 
fd75203
d4489c6
fd75203
 
1460a51
 
 
 
 
 
fd75203
1460a51
 
 
 
 
 
 
fd75203
d4489c6
1460a51
 
d4489c6
 
1460a51
 
 
 
d4489c6
 
1460a51
d4489c6
 
fd75203
d4489c6
 
 
fd75203
d4489c6
 
 
 
 
 
fd75203
d4489c6
 
 
 
 
 
 
fd75203
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from transformers import AutoProcessor, AutoModelForVision2Seq
from qwen_vl_utils import process_vision_info
import gradio as gr
from PIL import Image
import torch

# Load 72B AWQ model
model2 = AutoModelForVision2Seq.from_pretrained(
    "Qwen/Qwen2.5-VL-32B-Instruct",
    dtype=torch.float16,
    device_map="auto"
)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-32B-Instruct")

# Task instructions shown to the model (kept in German — part of the experiment's prompt).
GAME_RULES = """In diesem Bild sehen Sie drei Farbraster. In der folgenden Äußerung beschreibt der Sprecher genau eines der Gitter.
Bitte geben Sie mir an, ob er sich auf das
linke, mittlere oder rechte Farbraster bezieht.
"""

# Nine selectable example images: display label -> image file path.
# Files are resolved relative to the working directory at request time.
IMAGE_OPTIONS = {
    "Bild 1": "example1.jpg",
    "Bild 2": "example2.jpg",
    "Bild 3": "example3.jpg",
    "Bild 4": "example4.jpg",
    "Bild 5": "example5.jpg",
    "Bild 6": "example6.jpg",
    "Bild 7": "example7.jpg",
    "Bild 8": "example8.jpg",
    "Bild 9": "example9.jpg"
}

# Function to run the model on one (image, description) pair.
def play_game(selected_image_label, user_prompt):
    """Ask the model which of the three color grids the description refers to.

    Parameters:
        selected_image_label: key into IMAGE_OPTIONS (e.g. "Bild 2").
        user_prompt: the user's German description; may be None/empty.

    Returns:
        The model's decoded text answer (special tokens stripped).

    Raises:
        KeyError if the label is unknown; OSError if the image file is missing.
    """
    selected_image_path = IMAGE_OPTIONS[selected_image_label]
    # Open inside a context manager so the file handle is released promptly
    # (the original leaked it). convert("RGB") both normalizes the mode for the
    # vision processor and forces the lazy pixel data to load before close.
    with Image.open(selected_image_path) as img:
        selected_image = img.convert("RGB")

    # Build the chat-format message: image first, then rules + user description.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": selected_image},
                {"type": "text", "text": GAME_RULES + "\n" + (user_prompt or "")},
            ],
        }
    ]

    # Render the chat template and extract vision inputs with Qwen's utility.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)  # Use Qwen utility!

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model2.device)

    # Generate without gradient tracking to save memory.
    with torch.inference_mode():
        generated_ids = model2.generate(**inputs, max_new_tokens=512)
    # Strip the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return output_text

# --- Gradio front end ------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Column():
        image_selector = gr.Dropdown(
            choices=list(IMAGE_OPTIONS.keys()),
            value="Bild 2",
            label="Wählen Sie ein Bild",
        )
        image_display = gr.Image(
            value=Image.open(IMAGE_OPTIONS["Bild 2"]),
            label="Bild",
            interactive=False,
            type="pil",
        )
        prompt_input = gr.Textbox(value="Beschreibung", label="Ihre Beschreibung")
        output_text = gr.Textbox(label="Antwort des Modells")
        play_button = gr.Button("Spiel starten")

    def update_image(selected_label):
        """Return the PIL image for the currently selected dropdown label."""
        return Image.open(IMAGE_OPTIONS[selected_label])

    # Keep the preview in sync with the dropdown.
    image_selector.change(fn=update_image, inputs=[image_selector], outputs=image_display)

    # Button press runs the model on the chosen image + description.
    play_button.click(fn=play_game, inputs=[image_selector, prompt_input], outputs=output_text)

demo.launch()