prithivMLmods committed on
Commit
153a20d
·
verified ·
1 Parent(s): 383ed70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +435 -87
app.py CHANGED
@@ -1,17 +1,26 @@
1
- import gradio as gr
2
- import torch
3
- import spaces
4
  import json
5
  import ast
6
  import re
7
  from threading import Thread
8
  from PIL import Image
 
 
 
9
  from transformers import (
10
  Qwen3_5ForConditionalGeneration,
11
  AutoProcessor,
12
  TextIteratorStreamer,
13
  )
14
 
 
 
 
 
 
 
 
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
  DTYPE = (
17
  torch.bfloat16
@@ -23,58 +32,31 @@ MODEL_NAME = "Qwen/Qwen3.5-2B"
23
  CATEGORIES = ["Query", "Caption", "Point", "Detect"]
24
 
25
  print(f"Loading model: {MODEL_NAME} ...")
26
# Load the checkpoint named by MODEL_NAME once, at module level, using the
# dtype/device selected above, and put the model in eval mode.
qwen_model = Qwen3_5ForConditionalGeneration.from_pretrained(
    MODEL_NAME, torch_dtype=DTYPE, device_map=DEVICE,
).eval()
# Matching processor for the same checkpoint; used later for
# apply_chat_template.
qwen_processor = AutoProcessor.from_pretrained(MODEL_NAME)
print("Model loaded.")
31
-
32
-
33
def safe_parse_json(text: str):
    """Best-effort parse of model output into a Python object.

    One surrounding Markdown code fence is removed first, then strict JSON is
    tried, followed by a Python-literal fallback for single-quoted,
    JSON-like output.

    Args:
        text: Raw text, optionally wrapped in a fenced code block.

    Returns:
        The decoded value, or an empty dict when nothing could be parsed.
    """
    text = text.strip()
    # Opening fence: a bare ```, ```json glued directly to the payload, or any
    # language tag that sits on its own line (```JSON, ```python, ...). A tag
    # is only consumed when it is "json" or is followed by a newline, so a
    # payload such as ```true``` is not mistaken for a language tag.
    text = re.sub(
        r"^```(?:json\b|[\w+-]*[ \t]*\r?\n)?",
        "",
        text,
        flags=re.IGNORECASE,
    )
    text = re.sub(r"```$", "", text)
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        # literal_eval only evaluates literals, so it is safe on model output.
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return {}
46
-
47
-
48
def on_category_change(category: str):
    """Build the Textbox returned for a category change.

    The returned ``gr.Textbox`` carries the example placeholder for the
    selected category, or a generic hint for an unknown one.
    """
    fallback_hint = "Enter your prompt here."
    hints_by_task = dict(
        Query="e.g., Count the total number of boats and describe the environment.",
        Caption="e.g., short, normal, detailed",
        Point="e.g., The gun held by the person.",
        Detect="e.g., The headlight of the car.",
    )
    if category in hints_by_task:
        hint = hints_by_task[category]
    else:
        hint = fallback_hint
    return gr.Textbox(placeholder=hint)
56
-
57
-
58
- @spaces.GPU
59
- def process_inputs(image, category, prompt):
60
- if image is None:
61
- raise gr.Error("Please upload an image.")
62
- if not prompt or not prompt.strip():
63
- raise gr.Error("Please provide a prompt.")
64
-
65
- image = image.convert("RGB")
66
- image.thumbnail((512, 512))
67
 
 
 
68
  if category == "Query":
69
- full_prompt = prompt
70
  elif category == "Caption":
71
- full_prompt = f"Provide a {prompt} length caption for the image."
72
  elif category == "Point":
73
- full_prompt = f"Provide 2d point coordinates for {prompt}. Report in JSON format."
74
  elif category == "Detect":
75
- full_prompt = f"Provide bounding box coordinates for {prompt}. Report in JSON format."
76
- else:
77
- full_prompt = prompt
 
 
 
 
78
 
79
  messages = [
80
  {
@@ -85,6 +67,7 @@ def process_inputs(image, category, prompt):
85
  ],
86
  }
87
  ]
 
88
  text = qwen_processor.apply_chat_template(
89
  messages, tokenize=False, add_generation_prompt=True
90
  )
@@ -98,6 +81,7 @@ def process_inputs(image, category, prompt):
98
  skip_special_tokens=True,
99
  timeout=120,
100
  )
 
101
  thread = Thread(
102
  target=qwen_model.generate,
103
  kwargs=dict(
@@ -111,48 +95,412 @@ def process_inputs(image, category, prompt):
111
  )
112
  thread.start()
113
 
114
- full_text = ""
115
  for tok in streamer:
116
- full_text += tok
117
- yield full_text
118
 
119
  thread.join()
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
# Demo UI: one row with two columns. The first column holds the inputs
# (image, task category, prompt, run button); the second holds the output box.
with gr.Blocks() as demo:

    gr.Markdown("## Qwen 3.5 - Image Understanding")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", height=350)
            category_select = gr.Dropdown(
                choices=CATEGORIES,
                value="Query",
                label="Task Category",
                interactive=True,
            )
            prompt_input = gr.Textbox(
                placeholder="e.g., Count the total number of boats and describe the environment.",
                label="Prompt",
                lines=3,
            )
            run_btn = gr.Button("Run", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Output", lines=20, interactive=False)

    # Swap the prompt placeholder whenever the task category changes.
    category_select.change(
        fn=on_category_change,
        inputs=[category_select],
        outputs=[prompt_input],
    )
    # process_inputs is a generator that yields the accumulated text, which
    # feeds output_text.
    run_btn.click(
        fn=process_inputs,
        inputs=[image_input, category_select, prompt_input],
        outputs=[output_text],
    )
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)
 
1
+ import os
2
+ import io
 
3
  import json
4
  import ast
5
  import re
6
  from threading import Thread
7
  from PIL import Image
8
+
9
+ import torch
10
+ import spaces
11
  from transformers import (
12
  Qwen3_5ForConditionalGeneration,
13
  AutoProcessor,
14
  TextIteratorStreamer,
15
  )
16
 
17
+ from gradio import Server
18
+ from fastapi import Request, UploadFile, File, Form, HTTPException
19
+ from fastapi.responses import HTMLResponse, StreamingResponse
20
+
21
+ # --- App Configuration & Initializations ---
22
+ app = Server()
23
+
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
  DTYPE = (
26
  torch.bfloat16
 
32
  CATEGORIES = ["Query", "Caption", "Point", "Detect"]
33
 
34
  print(f"Loading model: {MODEL_NAME} ...")
35
# Bind both names before attempting the load so they always exist. Without
# this, a swallowed load failure leaves them undefined and the first request
# dies with a NameError instead of a detectable "model is None" state.
qwen_model = None
qwen_processor = None
try:
    qwen_model = Qwen3_5ForConditionalGeneration.from_pretrained(
        MODEL_NAME, torch_dtype=DTYPE, device_map=DEVICE,
    ).eval()
    qwen_processor = AutoProcessor.from_pretrained(MODEL_NAME)
    print("Model loaded successfully.")
except Exception as e:
    # Deliberate best-effort: the module must stay importable even when the
    # weights cannot be loaded (e.g. while the environment is being built).
    # Reset both names so a half-loaded model/processor pair is never used.
    qwen_model = None
    qwen_processor = None
    print(f"Warning: Model failed to load (ignoring if building environment). Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ # --- Helper Functions ---
45
def process_prompt_by_category(category: str, prompt: str) -> str:
    """Turn the user's prompt into the instruction sent to the model.

    "Caption", "Point" and "Detect" wrap the prompt in a fixed instruction
    sentence; "Query" and any unrecognised category return it unchanged.
    """
    instruction_templates = {
        "Caption": "Provide a {} length caption for the image.",
        "Point": "Provide 2d point coordinates for {}. Report in JSON format.",
        "Detect": "Provide bounding box coordinates for {}. Report in JSON format.",
    }
    template = instruction_templates.get(category)
    if template is None:
        return prompt
    return template.format(prompt)
55
+
56
+ # --- Generator with ZeroGPU Space wrapper ---
57
+ @spaces.GPU(duration=120)
58
+ def generate_stream(image: Image.Image, category: str, prompt: str):
59
+ full_prompt = process_prompt_by_category(category, prompt)
60
 
61
  messages = [
62
  {
 
67
  ],
68
  }
69
  ]
70
+
71
  text = qwen_processor.apply_chat_template(
72
  messages, tokenize=False, add_generation_prompt=True
73
  )
 
81
  skip_special_tokens=True,
82
  timeout=120,
83
  )
84
+
85
  thread = Thread(
86
  target=qwen_model.generate,
87
  kwargs=dict(
 
95
  )
96
  thread.start()
97
 
 
98
  for tok in streamer:
99
+ yield tok
 
100
 
101
  thread.join()
102
 
103
+ # --- FastAPI Endpoints ---
104
@app.post("/api/run")
async def run_node_graph(
    image: UploadFile = File(...),
    category: str = Form(...),
    prompt: str = Form(...),
):
    """Validate a multipart request and stream the model's answer as text.

    Args:
        image: Uploaded image file (required form part).
        category: Task name; must be one of CATEGORIES.
        prompt: User prompt; surrounding whitespace is removed before use.

    Returns:
        A ``text/plain`` StreamingResponse fed by ``generate_stream``.

    Raises:
        HTTPException: 400 for an empty prompt, an unknown category, or an
            empty or undecodable image; 503 when the model is not loaded.
    """
    # Forward the stripped prompt, not just validate it: the raw value is
    # interpolated mid-sentence by the prompt templates.
    prompt = prompt.strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")
    if category not in CATEGORIES:
        raise HTTPException(status_code=400, detail="Invalid Category")

    # Model loading at import time is best-effort and may have been skipped;
    # globals().get() also covers the case where the names were never bound.
    if globals().get("qwen_model") is None or globals().get("qwen_processor") is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    try:
        image_bytes = await image.read()
        # A required UploadFile is always truthy, so emptiness has to be
        # checked on the payload itself.
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Image is required")
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        pil_image.thumbnail((512, 512))  # Downscale to fit limits
    except HTTPException:
        raise
    except Exception as e:
        # Request boundary: any read/decode failure is reported as a bad upload.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    # Return a StreamingResponse to stream tokens to the frontend
    return StreamingResponse(
        generate_stream(pil_image, category, prompt),
        media_type="text/plain",
    )
129
 
130
@app.get("/", response_class=HTMLResponse)
async def homepage(request: Request):
    """Serve the single-page node-graph UI.

    The page is static HTML/CSS/JS: three draggable nodes (image input,
    task/prompt, text output) joined by SVG wires, plus a script that posts
    the form to ``/api/run`` and appends the streamed response to the output
    node.

    Args:
        request: Incoming request; not used by the handler.

    Returns:
        The full HTML document as a string.
    """
    # Raw string on purpose: the embedded JavaScript contains "\n" escapes
    # that must reach the browser verbatim. In a normal string Python turns
    # them into real line breaks, which leaves a double-quoted JS literal
    # unterminated and makes the whole <script> fail to parse.
    return r"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Multimodal-Edge-Comparator</title>
    <style>
        :root {
            --bg-color: #1e1e1e;
            --grid-color: #2a2a2a;
            --node-bg: #333333;
            --node-border: #444444;
            --text-color: #eeeeee;
            --port-color: #64b5f6;
            --wire-color: #81c784;

            --title-input: #e53935;
            --title-task: #1e88e5;
            --title-output: #43a047;
        }

        body {
            margin: 0;
            padding: 0;
            background-color: var(--bg-color);
            background-image:
                linear-gradient(var(--grid-color) 1px, transparent 1px),
                linear-gradient(90deg, var(--grid-color) 1px, transparent 1px);
            background-size: 20px 20px;
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            color: var(--text-color);
            overflow: hidden;
            width: 100vw;
            height: 100vh;
        }

        #topbar {
            position: absolute;
            top: 0; left: 0; right: 0;
            background: rgba(0,0,0,0.6);
            padding: 10px 20px;
            font-size: 18px;
            font-weight: bold;
            color: #ccc;
            z-index: 1000;
            pointer-events: none;
            display: flex;
            justify-content: space-between;
        }
        #topbar a { color: #fff; text-decoration: none; }

        /* SVG Wire Canvas */
        #wire-canvas {
            position: absolute;
            top: 0; left: 0;
            width: 100%; height: 100%;
            pointer-events: none;
            z-index: 1;
        }

        .wire-path {
            fill: none;
            stroke: var(--wire-color);
            stroke-width: 4;
            stroke-linecap: round;
        }

        /* Nodes */
        .node {
            position: absolute;
            background: var(--node-bg);
            border: 1px solid var(--node-border);
            border-radius: 8px;
            box-shadow: 0 4px 10px rgba(0,0,0,0.5);
            min-width: 250px;
            z-index: 10;
            display: flex;
            flex-direction: column;
        }

        .node-header {
            padding: 8px 12px;
            font-weight: bold;
            font-size: 14px;
            border-top-left-radius: 7px;
            border-top-right-radius: 7px;
            cursor: grab;
            user-select: none;
            color: white;
            text-shadow: 1px 1px 2px rgba(0,0,0,0.8);
        }
        .node-header:active { cursor: grabbing; }

        .node-content {
            padding: 15px;
            display: flex;
            flex-direction: column;
            gap: 10px;
        }

        /* Ports */
        .port {
            width: 14px; height: 14px;
            background: var(--port-color);
            border-radius: 50%;
            position: absolute;
            top: 50%;
            transform: translateY(-50%);
            border: 2px solid var(--node-bg);
            z-index: 15;
        }
        .port-out { right: -8px; }
        .port-in { left: -8px; }

        /* Controls */
        input[type="file"], select, textarea, button {
            width: 100%;
            box-sizing: border-box;
            background: #222;
            color: white;
            border: 1px solid #555;
            padding: 8px;
            border-radius: 4px;
            font-family: inherit;
        }
        textarea { resize: vertical; min-height: 60px; }
        button {
            background: #1e88e5;
            font-weight: bold;
            cursor: pointer;
            transition: background 0.2s;
            border: none;
        }
        button:hover { background: #1565c0; }
        button:disabled { background: #555; cursor: not-allowed; }

        .image-preview {
            width: 100%;
            height: 180px;
            background: #111;
            border-radius: 4px;
            object-fit: contain;
            display: none;
        }

        #output-text {
            min-height: 150px;
            max-height: 300px;
            overflow-y: auto;
            white-space: pre-wrap;
            font-family: monospace;
            font-size: 13px;
            color: #aed581;
            background: #111;
            padding: 10px;
            border-radius: 4px;
        }

        .label { font-size: 12px; color: #aaa; margin-bottom: 2px; }
        .control-group { display: flex; flex-direction: column; }
    </style>
</head>
<body>

    <div id="topbar">
        <a href="https://huggingface.co/spaces/prithivMLmods/Multimodal-Edge-Comparator" target="_blank">Multimodal-Edge-Comparator UI</a>
        <span>Qwen 3.5 - 2B Backend</span>
    </div>

    <svg id="wire-canvas">
        <path id="wire1" class="wire-path" d="" />
        <path id="wire2" class="wire-path" d="" />
    </svg>

    <!-- NODE 1: INPUT -->
    <div class="node" id="node-input" style="top: 150px; left: 100px; width: 280px;">
        <div class="node-header" style="background-color: var(--title-input);">1. Image Input Node</div>
        <div class="node-content">
            <div class="control-group">
                <span class="label">Upload Image</span>
                <input type="file" id="file-input" accept="image/*">
            </div>
            <img id="img-preview" class="image-preview">
        </div>
        <div class="port port-out" id="port-input-out"></div>
    </div>

    <!-- NODE 2: TASK & PROMPT -->
    <div class="node" id="node-task" style="top: 150px; left: 450px; width: 300px;">
        <div class="port port-in" id="port-task-in"></div>
        <div class="node-header" style="background-color: var(--title-task);">2. Processing Node</div>
        <div class="node-content">
            <div class="control-group">
                <span class="label">Task Category</span>
                <select id="category-select">
                    <option value="Query">Query</option>
                    <option value="Caption">Caption</option>
                    <option value="Point">Point</option>
                    <option value="Detect">Detect</option>
                </select>
            </div>
            <div class="control-group">
                <span class="label">Prompt Details</span>
                <textarea id="prompt-input" placeholder="e.g., Count the total number of boats and describe the environment."></textarea>
            </div>
            <button id="run-btn">Queue Run</button>
        </div>
        <div class="port port-out" id="port-task-out"></div>
    </div>

    <!-- NODE 3: OUTPUT -->
    <div class="node" id="node-output" style="top: 150px; left: 820px; width: 350px;">
        <div class="port port-in" id="port-output-in"></div>
        <div class="node-header" style="background-color: var(--title-output);">3. Text Output Node</div>
        <div class="node-content">
            <div class="control-group">
                <span class="label">Streamed Results</span>
                <div id="output-text">Awaiting execution...</div>
            </div>
        </div>
    </div>

    <script>
        // --- 1. Draggable Nodes Logic ---
        let draggedNode = null;
        let offsetX = 0, offsetY = 0;

        document.querySelectorAll('.node-header').forEach(header => {
            header.addEventListener('mousedown', (e) => {
                draggedNode = e.target.closest('.node');
                const rect = draggedNode.getBoundingClientRect();
                offsetX = e.clientX - rect.left;
                offsetY = e.clientY - rect.top;
                draggedNode.style.zIndex = 100; // bring to front
            });
        });

        document.addEventListener('mousemove', (e) => {
            if (!draggedNode) return;
            const x = e.clientX - offsetX;
            const y = e.clientY - offsetY;
            draggedNode.style.left = `${x}px`;
            draggedNode.style.top = `${y}px`;
            updateWires();
        });

        document.addEventListener('mouseup', () => {
            if (draggedNode) {
                draggedNode.style.zIndex = 10;
                draggedNode = null;
            }
        });

        // --- 2. Wire Connection Logic ---
        function getPortCenter(portId) {
            const el = document.getElementById(portId);
            const rect = el.getBoundingClientRect();
            return {
                x: rect.left + rect.width / 2,
                y: rect.top + rect.height / 2
            };
        }

        function drawCurve(x1, y1, x2, y2) {
            // Cubic bezier curve for a "ComfyUI wire" look
            const cx = (x1 + x2) / 2;
            return `M ${x1} ${y1} C ${cx} ${y1}, ${cx} ${y2}, ${x2} ${y2}`;
        }

        function updateWires() {
            // Wire 1: Input to Task
            const p1 = getPortCenter('port-input-out');
            const p2 = getPortCenter('port-task-in');
            document.getElementById('wire1').setAttribute('d', drawCurve(p1.x, p1.y, p2.x, p2.y));

            // Wire 2: Task to Output
            const p3 = getPortCenter('port-task-out');
            const p4 = getPortCenter('port-output-in');
            document.getElementById('wire2').setAttribute('d', drawCurve(p3.x, p3.y, p4.x, p4.y));
        }

        // Initialize wires on load
        window.addEventListener('resize', updateWires);
        updateWires();


        // --- 3. App Logic (Placeholders & Previews) ---
        const placeholders = {
            "Query": "e.g., Count the total number of boats...",
            "Caption": "e.g., short, normal, detailed",
            "Point": "e.g., The gun held by the person.",
            "Detect": "e.g., The headlight of the car."
        };

        const fileInput = document.getElementById('file-input');
        const imgPreview = document.getElementById('img-preview');
        const catSelect = document.getElementById('category-select');
        const promptInput = document.getElementById('prompt-input');
        const runBtn = document.getElementById('run-btn');
        const outText = document.getElementById('output-text');

        let currentFile = null;

        catSelect.addEventListener('change', (e) => {
            promptInput.placeholder = placeholders[e.target.value] || "";
        });

        fileInput.addEventListener('change', (e) => {
            const file = e.target.files[0];
            if (file) {
                currentFile = file;
                const url = URL.createObjectURL(file);
                imgPreview.src = url;
                imgPreview.style.display = 'block';
                updateWires(); // Re-adjust lines since node height changed
            }
        });


        // --- 4. Execution & Streaming Logic ---
        runBtn.addEventListener('click', async () => {
            if (!currentFile) return alert('Please upload an image first.');
            if (!promptInput.value.trim()) return alert('Please enter a prompt.');

            runBtn.disabled = true;
            runBtn.innerText = "Running...";
            outText.innerText = "Initializing connection to model...\n";

            const formData = new FormData();
            formData.append('image', currentFile);
            formData.append('category', catSelect.value);
            formData.append('prompt', promptInput.value);

            try {
                const response = await fetch('/api/run', {
                    method: 'POST',
                    body: formData
                });

                if (!response.ok) {
                    const errText = await response.text();
                    throw new Error(`Error: ${response.status} - ${errText}`);
                }

                outText.innerText = ""; // Clear loader

                // Read the streaming response chunks
                const reader = response.body.getReader();
                const decoder = new TextDecoder("utf-8");
                let done = false;

                while (!done) {
                    const { value, done: readerDone } = await reader.read();
                    done = readerDone;
                    if (value) {
                        const chunk = decoder.decode(value, { stream: true });
                        outText.innerText += chunk;
                        // Auto-scroll to bottom
                        outText.scrollTop = outText.scrollHeight;
                    }
                }
            } catch (err) {
                outText.innerText += `\n\n[Execution Failed]\n${err.message}`;
            } finally {
                runBtn.disabled = false;
                runBtn.innerText = "Queue Run";
            }
        });
    </script>
</body>
</html>
"""
505
 
506
# Start the server only when this file is executed as a script. Importing the
# module (for tooling or tests) should not launch it as a side effect.
if __name__ == "__main__":
    app.launch()