prithivMLmods committed on
Commit
aaa8883
·
verified ·
1 Parent(s): de9c364

update app

Browse files
Files changed (1) hide show
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ import json
5
+ import ast
6
+ import re
7
+ from threading import Thread
8
+ from PIL import Image
9
+ from transformers import (
10
+ Qwen3_5ForConditionalGeneration,
11
+ AutoProcessor,
12
+ TextIteratorStreamer,
13
+ )
14
+
15
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Precision used to load the model weights:
#   * GPU with bf16 support -> bfloat16
#   * other GPUs            -> float16
#   * CPU                   -> float32 (half precision on CPU is slow and not
#     implemented for every op, so it must not be the CPU fallback)
if DEVICE == "cuda":
    DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    DTYPE = torch.float32

# Checkpoint to load and the task categories offered in the UI dropdown.
MODEL_NAME = "Qwen/Qwen3.5-2B"
CATEGORIES = ["Query", "Caption", "Point", "Detect"]
24
+
25
# Load the model and its processor once at import time so every request
# reuses the same instances.
print(f"Loading model: {MODEL_NAME} ...")
qwen_model = Qwen3_5ForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map=DEVICE,
)
# Switch to inference mode (eval() returns the module itself).
qwen_model = qwen_model.eval()
qwen_processor = AutoProcessor.from_pretrained(MODEL_NAME)
print("Model loaded.")
31
+
32
+
33
def safe_parse_json(text: str):
    """Best-effort parse of model output into a Python object.

    The text is tried as JSON first and then as a Python literal (so
    single-quoted dicts/lists are accepted). A Markdown code fence wrapping
    the whole text is removed before parsing. If that fails, the first fenced
    block found anywhere in the text (for example a fenced answer preceded by
    an explanatory sentence) is tried as well.

    Args:
        text: Raw text to parse.

    Returns:
        The parsed object, or an empty dict when nothing could be parsed.
    """
    # Exceptions ast.literal_eval is documented to raise on malformed input.
    literal_errors = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)

    def _parse(candidate):
        """Return ``(True, value)`` if *candidate* parses, else ``(False, None)``."""
        try:
            return True, json.loads(candidate)
        except json.JSONDecodeError:
            pass
        try:
            return True, ast.literal_eval(candidate)
        except literal_errors:
            return False, None

    # First attempt: the whole text, minus a leading/trailing code fence.
    stripped = text.strip()
    stripped = re.sub(r"^```(json)?", "", stripped)
    stripped = re.sub(r"```$", "", stripped)
    parsed, value = _parse(stripped.strip())
    if parsed:
        return value

    # Fallback: a fenced block embedded in surrounding prose, or tagged with
    # a differently-cased "json" label.
    fenced = re.search(r"```(?:json)?(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        parsed, value = _parse(fenced.group(1).strip())
        if parsed:
            return value

    return {}
46
+
47
+
48
def on_category_change(category: str):
    """Return a prompt textbox whose placeholder matches the chosen category.

    Categories without a dedicated example get a generic placeholder.
    """
    examples = {
        "Query": "e.g., Count the total number of boats and describe the environment.",
        "Caption": "e.g., short, normal, detailed",
        "Point": "e.g., The gun held by the person.",
        "Detect": "e.g., The headlight of the car.",
    }
    try:
        hint = examples[category]
    except KeyError:
        hint = "Enter your prompt here."
    return gr.Textbox(placeholder=hint)
56
+
57
+
58
def _build_task_prompt(category, prompt):
    """Expand the user's prompt into the instruction for the given category.

    "Query" and any unrecognised category pass the prompt through unchanged.
    """
    templates = {
        "Caption": "Provide a {} length caption for the image.",
        "Point": "Provide 2d point coordinates for {}. Report in JSON format.",
        "Detect": "Provide bounding box coordinates for {}. Report in JSON format.",
    }
    template = templates.get(category)
    if template is None:
        return prompt
    return template.format(prompt)


@spaces.GPU
def process_inputs(image, category, prompt):
    """Run the model on an image and prompt, streaming the generated text.

    Args:
        image: PIL image from the UI, or None when nothing was uploaded.
        category: One of the task categories; selects the prompt template.
        prompt: User-entered text.

    Yields:
        The text generated so far, growing with each streamed chunk.

    Raises:
        gr.Error: If the image or prompt is missing, or if generation fails.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please provide a prompt.")

    # convert() returns a new image, so the in-place thumbnail() below does
    # not touch the caller's object. Bounds both sides to 512 px.
    image = image.convert("RGB")
    image.thumbnail((512, 512))

    # Strip surrounding whitespace so it does not leak into the templates.
    full_prompt = _build_task_prompt(category, prompt.strip())

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": full_prompt},
            ],
        }
    ]
    text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = qwen_processor(
        text=[text], images=[image], return_tensors="pt", padding=True
    ).to(qwen_model.device)

    streamer = TextIteratorStreamer(
        qwen_processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=120,
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        use_cache=True,
        # temperature/min_p only apply when sampling is enabled; make that
        # explicit instead of relying on the checkpoint's default config.
        do_sample=True,
        temperature=1.5,
        min_p=0.1,
    )

    errors = []

    def _generate():
        """Run generation; on failure record the error and release the consumer."""
        try:
            qwen_model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the request thread below
            errors.append(exc)
            # Close the stream so the loop below ends immediately rather than
            # blocking until the streamer timeout expires.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    full_text = ""
    for tok in streamer:
        full_text += tok
        yield full_text

    thread.join()

    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]
120
+
121
+
122
with gr.Blocks() as demo:
    gr.Markdown("## Qwen 3.5 - Image Understanding")

    with gr.Row():
        # Left column: image, task selection, prompt and the run button.
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil", height=350)
            category_select = gr.Dropdown(
                label="Task Category",
                choices=CATEGORIES,
                value="Query",
                interactive=True,
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=3,
                placeholder="e.g., Count the total number of boats and describe the environment.",
            )
            run_btn = gr.Button("Run", variant="primary")

        # Right column: read-only box that receives the streamed model output.
        with gr.Column():
            output_text = gr.Textbox(label="Output", lines=20, interactive=False)

    # Swap the prompt placeholder whenever a different task is selected.
    category_select.change(
        on_category_change,
        inputs=[category_select],
        outputs=[prompt_input],
    )
    # Run inference and stream the result into the output box.
    run_btn.click(
        process_inputs,
        inputs=[image_input, category_select, prompt_input],
        outputs=[output_text],
    )
155
+
156
+
157
if __name__ == "__main__":
    # Start the app only when executed as a script, with error display
    # enabled and SSR mode turned off.
    demo.launch(ssr_mode=False, show_error=True)