Translsis committed on
Commit
9399c1d
·
verified ·
1 Parent(s): cb253f6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +400 -935
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import os
2
  import cv2
 
3
  import spaces
4
  import gradio as gr
5
  import numpy as np
6
  import torch
7
  import matplotlib
 
8
  from PIL import Image, ImageDraw
9
  from typing import Iterable
10
  from gradio.themes import Soft
@@ -19,24 +21,46 @@ from datetime import datetime
19
  import threading
20
  import queue
21
  import uuid
22
- import shutil
23
- import zipfile
24
 
25
  # ============ THEME SETUP ============
26
  colors.steel_blue = colors.Color(
27
  name="steel_blue",
28
- c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2",
29
- c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C",
30
- c800="#2E5378", c900="#264364", c950="#1E3450"
 
 
 
 
 
 
 
 
31
  )
32
 
33
  class CustomBlueTheme(Soft):
34
- def __init__(self, *, primary_hue=colors.gray, secondary_hue=colors.steel_blue,
35
- neutral_hue=colors.slate, text_size=sizes.text_lg,
36
- font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
37
- font_mono=(fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace")):
38
- super().__init__(primary_hue=primary_hue, secondary_hue=secondary_hue,
39
- neutral_hue=neutral_hue, text_size=text_size, font=font, font_mono=font_mono)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  super().set(
41
  background_fill_primary="*primary_50",
42
  background_fill_primary_dark="*primary_900",
@@ -62,255 +86,76 @@ class CustomBlueTheme(Soft):
62
  app_theme = CustomBlueTheme()
63
 
64
  # ============ GLOBAL SETUP ============
65
- device = "cpu" # Force CPU usage
66
- print(f"🖥️ Using device: {device} (CPU mode for stability)")
67
 
 
68
  HISTORY_DIR = "processing_history"
69
- OUTPUTS_DIR = os.path.join(HISTORY_DIR, "outputs")
70
- DOWNLOADS_DIR = os.path.join(HISTORY_DIR, "downloads")
71
- os.makedirs(OUTPUTS_DIR, exist_ok=True)
72
- os.makedirs(DOWNLOADS_DIR, exist_ok=True)
73
  HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
74
 
 
75
  processing_queue = queue.Queue()
76
  processing_results = {}
77
 
78
  # Load models
79
- print("⏳ Loading SAM3 Models...")
80
  try:
81
- print(" Loading Image Model...")
82
  IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3").to(device)
83
  IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
84
-
85
- print(" Loading Tracker Model...")
86
  TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3").to(device)
87
  TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
88
-
89
- print(" Loading Video Model...")
90
- VID_MODEL = Sam3VideoModel.from_pretrained("DiffusionWave/sam3").to(device) # No bfloat16 for CPU
91
  VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("DiffusionWave/sam3")
92
 
93
- print("✅ All models loaded successfully on CPU!")
94
  except Exception as e:
95
- print(f"❌ Error loading models: {e}")
96
  IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = VID_MODEL = VID_PROCESSOR = None
97
 
98
  # ============ HISTORY MANAGEMENT ============
99
  def load_history():
 
100
  if os.path.exists(HISTORY_FILE):
101
  try:
102
- with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
103
  return json.load(f)
104
  except:
105
  return []
106
  return []
107
 
108
- def save_history(item):
109
- history = load_history()
110
- history.insert(0, item)
111
- history = history[:200]
112
- with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
113
- json.dump(history, f, indent=2, ensure_ascii=False)
114
-
115
- def get_history_stats():
116
- history = load_history()
117
- total = len(history)
118
- completed = sum(1 for h in history if h['status'] == 'completed')
119
- errors = sum(1 for h in history if h['status'] == 'error')
120
- types = {}
121
- for h in history:
122
- t = h['type']
123
- types[t] = types.get(t, 0) + 1
124
- return {
125
- 'total': total,
126
- 'completed': completed,
127
- 'errors': errors,
128
- 'success_rate': f"{(completed/total*100):.1f}%" if total > 0 else "0%",
129
- 'types': types
130
- }
131
-
132
- def create_download_package(item_id):
133
- history = load_history()
134
- item = next((h for h in history if h['id'] == item_id), None)
135
-
136
- if not item or item['status'] != 'completed':
137
- return None
138
-
139
- zip_path = os.path.join(DOWNLOADS_DIR, f"{item_id}_results.zip")
140
-
141
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
142
- metadata = {
143
- 'job_id': item_id,
144
- 'type': item['type'],
145
- 'prompt': item.get('prompt', 'N/A'),
146
- 'timestamp': item['timestamp'],
147
- 'duration': item.get('duration', 'N/A'),
148
- 'num_objects': item.get('num_objects', 0)
149
- }
150
- zipf.writestr('metadata.json', json.dumps(metadata, indent=2, ensure_ascii=False))
151
-
152
- if item['type'] == 'image':
153
- if item.get('output_path') and os.path.exists(item['output_path']):
154
- zipf.write(item['output_path'], 'overlay.jpg')
155
- if item.get('segmented_files'):
156
- for i, f in enumerate(item['segmented_files'], 1):
157
- if os.path.exists(f):
158
- zipf.write(f, f'objects/object_{i}.png')
159
-
160
- elif item['type'] == 'video':
161
- if item.get('output_path') and os.path.exists(item['output_path']):
162
- zipf.write(item['output_path'], 'overlay_video.mp4')
163
- if item.get('mask_video_path') and os.path.exists(item['mask_video_path']):
164
- zipf.write(item['mask_video_path'], 'masks_only.mp4')
165
- if item.get('segmented_video_path') and os.path.exists(item['segmented_video_path']):
166
- zipf.write(item['segmented_video_path'], 'segmented_video.mp4')
167
-
168
- elif item['type'] == 'click':
169
- if item.get('output_path') and os.path.exists(item['output_path']):
170
- zipf.write(item['output_path'], 'result.jpg')
171
-
172
- return zip_path
173
-
174
- def get_downloadable_jobs():
175
  history = load_history()
176
- choices = []
177
- for item in history:
178
- if item['status'] == 'completed':
179
- type_emoji = {'image': '📷', 'video': '🎥', 'click': '👆'}.get(item['type'], '📄')
180
- label = f"{type_emoji} [{item['type'].upper()}] {item['prompt'][:35]}... | {item['timestamp']}"
181
- choices.append((label, item['id']))
182
- return choices if choices else [("No completed jobs available", None)]
183
 
184
- def format_history_table():
 
185
  history = load_history()
186
  if not history:
187
- return "<p style='text-align:center; color:#666; padding:40px;'>📭 Chưa có lịch sử xử lý nào</p>"
188
-
189
- html = """
190
- <style>
191
- .history-table { width: 100%; border-collapse: collapse; font-size: 14px; background: white; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
192
- .history-table th { background: linear-gradient(90deg, #4682B4, #529AC3); color: white; padding: 14px 12px; text-align: left; font-weight: 600; text-transform: uppercase; font-size: 12px; letter-spacing: 0.5px; }
193
- .history-table td { padding: 12px; border-bottom: 1px solid #e8e8e8; vertical-align: middle; }
194
- .history-table tr:hover { background-color: #f8f9fa; }
195
- .history-table tr:last-child td { border-bottom: none; }
196
- .status-badge { padding: 5px 12px; border-radius: 14px; font-size: 11px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.5px; display: inline-block; }
197
- .status-completed { background: linear-gradient(135deg, #d4edda, #c3e6cb); color: #155724; }
198
- .status-error { background: linear-gradient(135deg, #f8d7da, #f5c6cb); color: #721c24; }
199
- .type-badge { padding: 5px 10px; border-radius: 10px; font-size: 11px; font-weight: 600; background: linear-gradient(135deg, #e3f2fd, #bbdefb); color: #1565c0; display: inline-block; }
200
- .prompt-text { max-width: 280px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; color: #333; font-weight: 500; }
201
- .file-count { font-size: 11px; color: #666; margin-top: 4px; line-height: 1.4; }
202
- .job-id { font-family: 'Courier New', monospace; font-size: 10px; color: #999; background: #f5f5f5; padding: 3px 6px; border-radius: 4px; }
203
- .time-info { font-size: 12px; color: #666; }
204
- .duration { font-size: 11px; color: #999; margin-top: 3px; }
205
- </style>
206
- <table class='history-table'>
207
- <thead>
208
- <tr>
209
- <th style='width: 40px; text-align: center;'>#</th>
210
- <th style='width: 110px;'>Job ID</th>
211
- <th style='width: 90px;'>Type</th>
212
- <th style='width: 110px;'>Status</th>
213
- <th>Prompt</th>
214
- <th style='width: 120px;'>Output Files</th>
215
- <th style='width: 140px;'>Time</th>
216
- </tr>
217
- </thead>
218
- <tbody>
219
- """
220
 
221
- for i, item in enumerate(history[:100], 1):
222
- status_class = f"status-{item['status']}"
223
- status_text = "✅ Completed" if item['status'] == 'completed' else "❌ Error"
224
-
225
- type_icons = {'image': '📷', 'video': '🎥', 'click': '👆'}
226
- type_icon = type_icons.get(item['type'], '📄')
227
-
228
- prompt = item.get('prompt', 'N/A')
229
- prompt_short = prompt[:45] + ('...' if len(prompt) > 45 else '')
230
-
231
- file_info = []
232
  if item.get('output_path'):
233
- file_info.append("✓ Overlay")
234
- if item.get('segmented_files'):
235
- file_info.append(f"✓ {len(item['segmented_files'])} Objects")
236
- if item.get('mask_video_path'):
237
- file_info.append("✓ Masks")
238
- if item.get('segmented_video_path'):
239
- file_info.append("✓ Segmented")
240
-
241
- files_text = "<br>".join(file_info) if file_info else "No files"
242
-
243
- html += f"""
244
- <tr>
245
- <td style='text-align: center; font-weight: 600; color: #999;'>{i}</td>
246
- <td><span class='job-id'>{item['id'][:12]}</span></td>
247
- <td><span class='type-badge'>{type_icon} {item['type'].upper()}</span></td>
248
- <td><span class='status-badge {status_class}'>{status_text}</span></td>
249
- <td class='prompt-text' title='{prompt}'>{prompt_short}</td>
250
- <td><div class='file-count'>{files_text}</div></td>
251
- <td>
252
- <div class='time-info'>{item['timestamp']}</div>
253
- <div class='duration'>⏱️ {item.get('duration', 'N/A')}</div>
254
- </td>
255
- </tr>
256
- """
257
-
258
- html += """
259
- </tbody>
260
- </table>
261
- """
262
-
263
- return html
264
 
265
- def get_history_gallery():
266
- history = load_history()
267
- gallery_items = []
268
-
269
- for item in history[:30]:
270
- if item['status'] == 'completed':
271
- if item.get('output_path') and os.path.exists(item['output_path']):
272
- caption = f"[{item['type'].upper()}] {item['prompt'][:35]}... | {item['timestamp']}"
273
- gallery_items.append((item['output_path'], caption))
274
-
275
- return gallery_items if gallery_items else []
276
-
277
- def search_history(keyword, filter_type, filter_status):
278
- history = load_history()
279
- filtered = history
280
-
281
- if keyword:
282
- filtered = [h for h in filtered if keyword.lower() in h.get('prompt', '').lower()]
283
- if filter_type and filter_type != "all":
284
- filtered = [h for h in filtered if h['type'] == filter_type]
285
- if filter_status and filter_status != "all":
286
- filtered = [h for h in filtered if h['status'] == filter_status]
287
-
288
- return filtered
289
-
290
- def clear_all_history():
291
- if os.path.exists(OUTPUTS_DIR):
292
- shutil.rmtree(OUTPUTS_DIR)
293
- os.makedirs(OUTPUTS_DIR)
294
- if os.path.exists(DOWNLOADS_DIR):
295
- shutil.rmtree(DOWNLOADS_DIR)
296
- os.makedirs(DOWNLOADS_DIR)
297
-
298
- with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
299
- json.dump([], f)
300
-
301
- return "✅ Đã xóa toàn bộ lịch sử và files"
302
-
303
- def export_history_json():
304
- history = load_history()
305
- export_path = os.path.join(HISTORY_DIR, f"history_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
306
-
307
- with open(export_path, 'w', encoding='utf-8') as f:
308
- json.dump(history, f, indent=2, ensure_ascii=False)
309
-
310
- return export_path
311
-
312
- # ============ PROCESSING UTILS ============
313
  def apply_mask_overlay(base_image, mask_data, opacity=0.5):
 
314
  if isinstance(base_image, np.ndarray):
315
  base_image = Image.fromarray(base_image)
316
  base_image = base_image.convert("RGBA")
@@ -322,10 +167,8 @@ def apply_mask_overlay(base_image, mask_data, opacity=0.5):
322
  mask_data = mask_data.cpu().numpy()
323
  mask_data = mask_data.astype(np.uint8)
324
 
325
- if mask_data.ndim == 4:
326
- mask_data = mask_data[0]
327
- if mask_data.ndim == 3 and mask_data.shape[0] == 1:
328
- mask_data = mask_data[0]
329
 
330
  num_masks = mask_data.shape[0] if mask_data.ndim == 3 else 1
331
  if mask_data.ndim == 2:
@@ -334,262 +177,44 @@ def apply_mask_overlay(base_image, mask_data, opacity=0.5):
334
 
335
  try:
336
  color_map = matplotlib.colormaps["rainbow"].resampled(max(num_masks, 1))
337
- except:
338
  import matplotlib.cm as cm
339
  color_map = cm.get_cmap("rainbow").resampled(max(num_masks, 1))
340
 
341
  rgb_colors = [tuple(int(c * 255) for c in color_map(i)[:3]) for i in range(num_masks)]
342
  composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
343
 
344
- for i, mask in enumerate(mask_data):
345
- mask_img = Image.fromarray((mask * 255).astype(np.uint8))
346
- if mask_img.size != base_image.size:
347
- mask_img = mask_img.resize(base_image.size, resample=Image.NEAREST)
348
 
349
- color_fill = Image.new("RGBA", base_image.size, rgb_colors[i] + (0,))
350
- mask_alpha = mask_img.point(lambda v: int(v * opacity) if v > 0 else 0)
 
351
  color_fill.putalpha(mask_alpha)
352
  composite_layer = Image.alpha_composite(composite_layer, color_fill)
353
 
354
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
355
 
356
  def draw_points_on_image(image, points):
 
357
  if isinstance(image, np.ndarray):
358
  image = Image.fromarray(image)
 
359
  draw_img = image.copy()
360
  draw = ImageDraw.Draw(draw_img)
361
- for x, y in points:
 
 
362
  r = 8
363
  draw.ellipse((x-r, y-r, x+r, y+r), fill="red", outline="white", width=4)
364
- return draw_img
365
-
366
- # ============ JOB PROCESSORS ============
367
- def process_image_job(job):
368
- start = datetime.now()
369
- img = job['image']
370
- if isinstance(img, str):
371
- img = Image.open(img)
372
-
373
- img = img.convert("RGB")
374
- inputs = IMG_PROCESSOR(images=img, text=job['prompt'], return_tensors="pt").to(device)
375
-
376
- with torch.no_grad():
377
- outputs = IMG_MODEL(**inputs)
378
-
379
- results = IMG_PROCESSOR.post_process_instance_segmentation(
380
- outputs,
381
- threshold=job.get('conf_thresh', 0.5),
382
- mask_threshold=0.5,
383
- target_sizes=inputs.get("original_sizes").tolist()
384
- )[0]
385
-
386
- masks = results['masks'].cpu().numpy()
387
- scores = results['scores'].cpu().numpy()
388
- annotations = [(m, f"{job['prompt']} ({s:.2f})") for m, s in zip(masks, scores)]
389
-
390
- out_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.jpg")
391
- apply_mask_overlay(img, masks).save(out_path)
392
-
393
- seg_files = []
394
- for i, mask in enumerate(masks):
395
- mask_bool = mask.astype(bool)
396
- seg = Image.new("RGBA", img.size, (0, 0, 0, 0))
397
- arr = np.array(img.convert("RGBA"))
398
- arr[~mask_bool] = [0, 0, 0, 0]
399
- seg = Image.fromarray(arr)
400
-
401
- # Fix: Convert mask to uint8 before creating Image
402
- mask_uint8 = (mask * 255).astype(np.uint8)
403
- bbox = Image.fromarray(mask_uint8).getbbox()
404
-
405
- if bbox:
406
- seg_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_obj_{i+1}.png")
407
- seg.crop(bbox).save(seg_path)
408
- seg_files.append(seg_path)
409
-
410
- return {
411
- 'image': (img, annotations),
412
- 'output_path': out_path,
413
- 'segmented_files': seg_files,
414
- 'num_objects': len(seg_files),
415
- 'duration': f"{(datetime.now() - start).total_seconds():.2f}s"
416
- }
417
-
418
- def process_video_job(job):
419
- """Process video on CPU - slower but no timeout"""
420
- start = datetime.now()
421
- cap = cv2.VideoCapture(job['video'])
422
- fps = cap.get(cv2.CAP_PROP_FPS)
423
- w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
424
-
425
- frames = []
426
- limit = job.get('frame_limit', 60)
427
- if limit == 0 or limit > 500:
428
- limit = 500 # Higher limit for CPU since no GPU timeout
429
-
430
- count = 0
431
- while cap.isOpened():
432
- ret, frame = cap.read()
433
- if not ret or count >= limit:
434
- break
435
- frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
436
- count += 1
437
- cap.release()
438
-
439
- print(f"📹 Processing {len(frames)} frames on CPU (this will take longer)...")
440
-
441
- # Process in chunks to manage memory
442
- chunk_size = 30 # Smaller chunks for CPU
443
-
444
- out_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.mp4")
445
- mask_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_masks.mp4")
446
- seg_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_segmented.mp4")
447
-
448
- writers = [
449
- cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)),
450
- cv2.VideoWriter(mask_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)),
451
- cv2.VideoWriter(seg_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
452
- ]
453
-
454
- total = len(frames)
455
- processed = 0
456
-
457
- # Process frames in chunks
458
- for chunk_start in range(0, total, chunk_size):
459
- chunk_end = min(chunk_start + chunk_size, total)
460
- chunk_frames = frames[chunk_start:chunk_end]
461
-
462
- print(f"🔄 Processing chunk {chunk_start}-{chunk_end} ({len(chunk_frames)} frames)")
463
-
464
- try:
465
- # Initialize session for this chunk
466
- session = VID_PROCESSOR.init_video_session(
467
- video=chunk_frames,
468
- inference_device=device
469
- )
470
- session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=job['prompt'])
471
-
472
- # Process chunk
473
- for idx, out in enumerate(VID_MODEL.propagate_in_video_iterator(
474
- inference_session=session,
475
- max_frame_num_to_track=len(chunk_frames)
476
- )):
477
- try:
478
- proc = VID_PROCESSOR.postprocess_outputs(session, out)
479
- f_idx = out.frame_idx
480
- orig = Image.fromarray(chunk_frames[f_idx])
481
-
482
- if 'masks' in proc:
483
- masks = proc['masks']
484
- if masks.ndim == 4:
485
- masks = masks.squeeze(1)
486
-
487
- overlay = apply_mask_overlay(orig, masks)
488
- writers[0].write(cv2.cvtColor(np.array(overlay), cv2.COLOR_RGB2BGR))
489
-
490
- mask_np = masks.cpu().numpy() if isinstance(masks, torch.Tensor) else masks
491
- combined = np.zeros((h, w), dtype=np.uint8)
492
- for m in mask_np:
493
- if m.shape != (h, w):
494
- m = cv2.resize(m.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST)
495
- combined = np.maximum(combined, m)
496
-
497
- mask_frame = np.zeros((h, w, 3), dtype=np.uint8)
498
- mask_frame[combined > 0] = [255, 255, 255]
499
- writers[1].write(mask_frame)
500
-
501
- seg_arr = np.array(orig.convert("RGBA"))
502
- seg_arr[:, :, 3] = (combined * 255).astype(np.uint8)
503
- bgr = np.zeros((h, w, 3), dtype=np.uint8)
504
- bgr[:, :] = [0, 255, 0]
505
- for c in range(3):
506
- bgr[:, :, c] = np.where(combined > 0, seg_arr[:, :, 2-c], bgr[:, :, c])
507
- writers[2].write(bgr)
508
- else:
509
- orig_bgr = cv2.cvtColor(np.array(orig), cv2.COLOR_RGB2BGR)
510
- writers[0].write(orig_bgr)
511
- writers[1].write(np.zeros((h, w, 3), dtype=np.uint8))
512
- writers[2].write(orig_bgr)
513
-
514
- processed += 1
515
- progress = int((processed / total) * 100)
516
- processing_results[job['id']]['progress'] = progress
517
-
518
- if processed % 5 == 0:
519
- elapsed = (datetime.now() - start).total_seconds()
520
- avg_time = elapsed / processed
521
- remaining = (total - processed) * avg_time
522
- print(f"⏳ Progress: {progress}% ({processed}/{total}) | ETA: {remaining/60:.1f} min")
523
-
524
- except Exception as e:
525
- print(f"⚠️ Error processing frame {f_idx}: {e}")
526
- orig_bgr = cv2.cvtColor(np.array(orig), cv2.COLOR_RGB2BGR)
527
- writers[0].write(orig_bgr)
528
- writers[1].write(np.zeros((h, w, 3), dtype=np.uint8))
529
- writers[2].write(orig_bgr)
530
- processed += 1
531
-
532
- # Clear memory after each chunk
533
- del session
534
-
535
- except Exception as e:
536
- print(f"❌ Error processing chunk: {e}")
537
- for i in range(chunk_start, chunk_end):
538
- if i < len(frames):
539
- orig_bgr = cv2.cvtColor(frames[i], cv2.COLOR_RGB2BGR)
540
- writers[0].write(orig_bgr)
541
- writers[1].write(np.zeros((h, w, 3), dtype=np.uint8))
542
- writers[2].write(orig_bgr)
543
- processed += 1
544
-
545
- for w in writers:
546
- w.release()
547
 
548
- print(f"✅ Video completed: {processed} frames in {(datetime.now() - start).total_seconds():.2f}s")
549
-
550
- return {
551
- 'output_path': out_path,
552
- 'mask_video_path': mask_path,
553
- 'segmented_video_path': seg_path,
554
- 'duration': f"{(datetime.now() - start).total_seconds():.2f}s"
555
- }
556
-
557
- def process_click_job(job):
558
- start = datetime.now()
559
- img = job['image']
560
- if isinstance(img, str):
561
- img = Image.open(img)
562
-
563
- inputs = TRK_PROCESSOR(
564
- images=img,
565
- input_points=[[job['points']]],
566
- input_labels=[[job['labels']]],
567
- return_tensors="pt"
568
- ).to(device)
569
-
570
- with torch.no_grad():
571
- outputs = TRK_MODEL(**inputs, multimask_output=False)
572
-
573
- masks = TRK_PROCESSOR.post_process_masks(
574
- outputs.pred_masks.cpu(),
575
- inputs["original_sizes"],
576
- binarize=True
577
- )[0]
578
-
579
- result = apply_mask_overlay(img, masks[0])
580
- result = draw_points_on_image(result, job['points'])
581
-
582
- out_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_result.jpg")
583
- result.save(out_path)
584
-
585
- return {
586
- 'image': result,
587
- 'output_path': out_path,
588
- 'duration': f"{(datetime.now() - start).total_seconds():.2f}s"
589
- }
590
 
591
- # ============ BACKGROUND WORKER ============
592
  def background_worker():
 
593
  while True:
594
  try:
595
  job = processing_queue.get()
@@ -599,8 +224,6 @@ def background_worker():
599
  job_id = job['id']
600
  job_type = job['type']
601
 
602
- print(f"🚀 Starting job {job_id[:8]} - Type: {job_type}")
603
-
604
  processing_results[job_id] = {'status': 'processing', 'progress': 0}
605
 
606
  try:
@@ -617,22 +240,17 @@ def background_worker():
617
  'progress': 100
618
  }
619
 
620
- print(f"✅ Job {job_id[:8]} completed successfully")
621
-
622
  save_history({
623
  'id': job_id,
624
  'type': job_type,
625
  'prompt': job.get('prompt', 'N/A'),
626
  'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
627
  'status': 'completed',
628
- **result
629
  })
630
 
631
  except Exception as e:
632
- print(f"❌ Job {job_id[:8]} failed: {str(e)}")
633
- import traceback
634
- traceback.print_exc()
635
-
636
  processing_results[job_id] = {
637
  'status': 'error',
638
  'error': str(e),
@@ -647,513 +265,360 @@ def background_worker():
647
  'error': str(e)
648
  })
649
  except Exception as e:
650
- print(f"⚠️ Worker error: {e}")
651
- import traceback
652
- traceback.print_exc()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
 
654
- threading.Thread(target=background_worker, daemon=True).start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
656
- # ============ GRADIO UI ============
657
- custom_css = """
658
- #col-container { margin: 0 auto; max-width: 1400px; }
659
- #main-title h1 { font-size: 2.2em !important; font-weight: 700; background: linear-gradient(135deg, #4682B4, #764ba2); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
660
- .stat-card { padding: 24px; border-radius: 16px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; text-align: center; box-shadow: 0 4px 15px rgba(0,0,0,0.2); }
661
- .stat-number { font-size: 2.8em; font-weight: 800; margin: 12px 0; text-shadow: 2px 2px 4px rgba(0,0,0,0.2); }
662
- .stat-label { font-size: 1.1em; opacity: 0.95; font-weight: 500; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
  """
664
 
665
- with gr.Blocks(title="SAM3 Segmentation") as demo:
666
  with gr.Column(elem_id="col-container"):
667
  gr.Markdown("# **SAM3: Segment Anything Model 3** 🚀", elem_id="main-title")
668
- gr.Markdown("### 💻 CPU Mode - Xử lý không giới hạn thời gian | Background processing | Download đầy đủ kết quả")
669
-
670
- gr.Markdown("""
671
- <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 10px; color: white; margin-bottom: 20px;'>
672
- <strong>🔥 Đặc điểm CPU Mode:</strong><br>
673
- ✅ Không bị timeout - xử lý video dài thoải mái<br>
674
- ⏱️ Chậm hơn nhưng ổn định - tốc độ ~2-3 phút/frame<br>
675
- 🔋 Chạy background - submit job và làm việc khác<br>
676
- 💾 Tự động lưu lịch sử và download được
677
- </div>
678
- """)
679
-
680
  with gr.Tabs():
681
- # ===== IMAGE TAB =====
682
  with gr.Tab("📷 Image Segmentation"):
683
  with gr.Row():
684
  with gr.Column(scale=1):
685
- click_input = gr.Image(
686
- type="pil",
687
- label="📤 Upload Image & Click Objects",
688
- interactive=True,
689
- height=450
690
- )
691
- gr.Markdown("""
692
- **📝 Hướng dẫn:**
693
- 1. Upload ảnh
694
- 2. Click vào đối tượng bạn muốn phân đoạn
695
- 3. Kết quả hiển thị ngay lập tức
696
- 4. Click "Clear" để reset và bắt đầu lại
697
- """)
698
-
699
- click_clear = gr.Button("🔄 Clear Points & Reset", variant="primary")
700
-
701
- click_pts = gr.State([])
702
- click_lbl = gr.State([])
703
-
704
- with gr.Column(scale=1):
705
- img_input = gr.Image(label="📤 Upload Image", type="pil", height=350)
706
- img_prompt = gr.Textbox(
707
- label="✍️ Text Prompt",
708
- placeholder="e.g., cat, person, car, building...",
709
- lines=2
710
- )
711
- with gr.Accordion("⚙️ Advanced Settings", open=False):
712
- img_conf = gr.Slider(0.0, 1.0, 0.45, 0.05, label="Confidence Threshold")
713
-
714
- img_submit = gr.Button("🚀 Submit Job (Background)", variant="primary", size="lg")
715
- img_check = gr.Button("🔍 Check Status", variant="secondary")
716
- img_job_id = gr.Textbox(label="Job ID", visible=False)
717
-
718
  with gr.Column(scale=1.5):
719
- img_result = gr.AnnotatedImage(label="🎨 Segmented Result (Overlay)", height=410)
720
- img_status = gr.Textbox(label="📊 Status", interactive=False)
721
-
722
- with gr.Accordion("📦 Extracted Objects", open=True):
723
- gr.Markdown("**Các đối tượng được tách ra (PNG với nền trong suốt):**")
724
- img_gallery = gr.Gallery(
725
- label="Segmented Objects",
726
- columns=3,
727
- height=300,
728
- object_fit="contain"
729
- )
730
-
731
- def submit_img(img, prompt, conf):
732
- if not img or not prompt:
733
- return None, "❌ Vui lòng cung cấp ảnh và prompt", "", []
734
- jid = str(uuid.uuid4())
735
- processing_queue.put({
736
- 'id': jid,
737
- 'type': 'image',
738
- 'image': img,
739
- 'prompt': prompt,
740
- 'conf_thresh': conf
741
- })
742
- return None, f"✅ Đã thêm vào hàng chờ (ID: {jid[:8]}). Đang xử lý...", jid, []
743
 
744
- def check_img(jid):
745
- if not jid or jid not in processing_results:
746
- return None, "❌ Không tìm thấy công việc", []
747
-
748
- r = processing_results[jid]
749
-
750
- if r['status'] == 'processing':
751
- return None, f"⏳ Đang xử lý... {r['progress']}%", []
752
- elif r['status'] == 'completed':
753
- res = r['result']
754
- gal = [f for f in res.get('segmented_files', []) if os.path.exists(f)]
755
- status = f"✅ Hoàn thành! Đã tách được {len(gal)} đối tượng | Thời gian: {res.get('duration', 'N/A')}"
756
- return res['image'], status, gal
757
- else:
758
- return None, f"❌ Lỗi: {r.get('error', 'Unknown')}", []
759
-
760
- img_submit.click(
761
- fn=submit_img,
762
- inputs=[img_input, img_prompt, img_conf],
763
- outputs=[img_result, img_status, img_job_id, img_gallery]
764
  )
765
 
766
- img_check.click(
767
- fn=check_img,
768
- inputs=[img_job_id],
769
- outputs=[img_result, img_status, img_gallery]
770
  )
771
 
772
- # ===== VIDEO TAB =====
773
  with gr.Tab("🎥 Video Segmentation"):
774
  with gr.Row():
775
  with gr.Column():
776
- vid_input = gr.Video(label="📤 Upload Video", format="mp4", height=320)
777
- vid_prompt = gr.Textbox(
778
- label="✍️ Text Prompt",
779
- placeholder="e.g., person running, red car, dog...",
780
- lines=2
781
- )
782
-
783
- with gr.Accordion("⚙️ Settings", open=True):
784
- vid_frames = gr.Slider(
785
- 10, 500, 60, 10,
786
- label="Max Frames",
787
- info="CPU mode: Có thể xử lý nhiều frames hơn, nhưng sẽ chậm hơn"
788
- )
789
-
790
- gr.Markdown("""
791
- **💻 CPU Processing Mode:**
792
- - ✅ **Không bị timeout** - xử lý bao nhiêu cũng được
793
- - ⏱️ **Chậm hơn GPU** - khoảng 2-3 phút/frame
794
- - 🔋 **Ổn định** - không crash, chạy nền background
795
-
796
- **⏱️ Thời gian ước tính:**
797
- - 30 frames: ~60-90 phút
798
- - 60 frames: ~2-3 giờ
799
- - 100 frames: ~3-5 giờ
800
-
801
- **💡 Khuyến nghị:**
802
- - Submit job và làm việc khác
803
- - Nhấn "Check Status" để xem tiến độ
804
- - Video sẽ được lưu khi hoàn thành
805
- """)
806
-
807
- vid_submit = gr.Button("🚀 Submit Job (Background)", variant="primary", size="lg")
808
- vid_check = gr.Button("🔍 Check Status", variant="secondary")
809
- vid_job_id = gr.Textbox(label="Job ID", visible=False)
810
-
811
- with gr.Column():
812
- gr.Markdown("### 📹 Video Outputs (3 versions)")
813
-
814
- with gr.Tabs():
815
- with gr.Tab("1️⃣ Overlay"):
816
- vid_overlay = gr.Video(label="Original + Color Masks")
817
- gr.Markdown("*Video gốc với màu mask phủ lên*")
818
-
819
- with gr.Tab("2️⃣ Masks Only"):
820
- vid_masks = gr.Video(label="White Masks on Black")
821
- gr.Markdown("*Chỉ hiển thị mask màu trắng trên nền đen*")
822
-
823
- with gr.Tab("3️⃣ Segmented"):
824
- vid_segmented = gr.Video(label="Green Screen Background")
825
- gr.Markdown("*Đối tượng với nền xanh lá (green screen)*")
826
-
827
- vid_status = gr.Textbox(label="📊 Status", interactive=False)
828
-
829
- def submit_vid(vid, prompt, frames):
830
- if not vid or not prompt:
831
- return None, None, None, "❌ Vui lòng cung cấp video và prompt", ""
832
- jid = str(uuid.uuid4())
833
- processing_queue.put({
834
- 'id': jid,
835
- 'type': 'video',
836
- 'video': vid,
837
- 'prompt': prompt,
838
- 'frame_limit': frames
839
- })
840
- return None, None, None, f"✅ Đã thêm vào hàng chờ (ID: {jid[:8]}). Đang xử lý...", jid
841
-
842
- def check_vid(jid):
843
- if not jid or jid not in processing_results:
844
- return None, None, None, "❌ Không tìm thấy công việc"
845
-
846
- r = processing_results[jid]
847
-
848
- if r['status'] == 'processing':
849
- return None, None, None, f"⏳ Đang xử lý... {r['progress']}%"
850
- elif r['status'] == 'completed':
851
- res = r['result']
852
- status = f"""✅ Hoàn thành! Thời gian: {res.get('duration', 'N/A')}
853
-
854
- 📹 3 video đã được tạo:
855
- • Overlay - Ảnh gốc với mask màu
856
- • Masks Only - Chỉ mask (trắng/đen)
857
- • Segmented - Đối tượng với green screen"""
858
- return (
859
- res.get('output_path'),
860
- res.get('mask_video_path'),
861
- res.get('segmented_video_path'),
862
- status
863
- )
864
- else:
865
- return None, None, None, f"❌ Lỗi: {r.get('error', 'Unknown')}"
866
-
867
- vid_submit.click(
868
- fn=submit_vid,
869
- inputs=[vid_input, vid_prompt, vid_frames],
870
- outputs=[vid_overlay, vid_masks, vid_segmented, vid_status, vid_job_id]
871
- )
872
-
873
- vid_check.click(
874
- fn=check_vid,
875
- inputs=[vid_job_id],
876
- outputs=[vid_overlay, vid_masks, vid_segmented, vid_status]
877
- )
878
-
879
- # ===== CLICK TAB =====
880
- with gr.Tab("👆 Click Segmentation"):
881
- with gr.Row():
882
- with gr.Column(scale=1):
883
- click_input = gr.Image(
884
- type="pil",
885
- label="📤 Upload Image & Click Objects",
886
- interactive=True,
887
- height=450
888
- )
889
- gr.Markdown("""
890
- **📝 Hướng dẫn:**
891
- 1. Upload ảnh
892
- 2. Click vào đối tượng bạn muốn phân đoạn
893
- 3. Kết quả hiển thị ngay lập tức
894
- 4. Click "Clear" để reset và bắt đầu lại
895
- """)
896
 
897
- click_clear = gr.Button("🔄 Clear Points & Reset", variant="primary")
 
 
898
 
899
- click_pts = gr.State([])
900
- click_lbl = gr.State([])
901
-
902
- with gr.Column(scale=1):
903
- click_output = gr.Image(
904
- type="pil",
905
- label="🎨 Result Preview",
906
- height=450,
907
- interactive=False
908
- )
909
 
910
- gr.Markdown("""
911
- **💡 Tips:**
912
- - Click vào trung tâm của đối tượng để có kết quả tốt nhất
913
- - Các điểm click được hiển thị bằng dấu chấm đỏ
914
- - Kết quả tự động cập nhật sau mỗi lần click
915
- """)
916
-
917
- def on_click(img, evt: gr.SelectData, pts, lbl):
918
- if pts is None:
919
- pts = []
920
- if lbl is None:
921
- lbl = []
922
-
923
- pts.append([evt.index[0], evt.index[1]])
924
- lbl.append(1)
925
-
926
- jid = str(uuid.uuid4())
927
- try:
928
- res = process_click_job({
929
- 'id': jid,
930
- 'type': 'click',
931
- 'image': img,
932
- 'points': pts,
933
- 'labels': lbl
934
- })
935
- return res['image'], pts, lbl
936
- except Exception as e:
937
- print(f"Click error: {e}")
938
- return img, pts, lbl
939
-
940
- click_input.select(
941
- fn=on_click,
942
- inputs=[click_input, click_pts, click_lbl],
943
- outputs=[click_output, click_pts, click_lbl]
944
  )
945
 
946
- click_clear.click(
947
- fn=lambda: (None, [], []),
948
- outputs=[click_output, click_pts, click_lbl]
 
949
  )
950
-
951
- # ===== DOWNLOAD TAB =====
952
- with gr.Tab("📥 Download Results"):
953
- gr.Markdown("""
954
- # 📦 Download Center
955
- ### Tải về kết quả đã xử lý dưới dạng ZIP
956
- """)
957
 
 
 
958
  with gr.Row():
959
  with gr.Column(scale=1):
960
- gr.Markdown("### 🎯 Select Job to Download")
961
-
962
- download_dropdown = gr.Dropdown(
963
- label="Chọn công việc đã hoàn thành",
964
- choices=get_downloadable_jobs(),
965
- interactive=True,
966
- scale=1
967
- )
968
 
969
  with gr.Row():
970
- download_refresh = gr.Button("🔄 Refresh List", variant="secondary", scale=1)
971
- download_btn = gr.Button("📥 Download ZIP", variant="primary", size="lg", scale=2)
972
 
973
- download_status = gr.Textbox(label="Status", interactive=False)
974
-
 
975
  with gr.Column(scale=1):
976
- gr.Markdown("### 📄 Download File")
977
- download_file = gr.File(label="Your ZIP file will appear here")
978
-
979
- gr.Markdown("""
980
- **📦 Package Contents:**
981
-
982
- **Image Jobs:**
983
- - `overlay.jpg` - Ảnh với mask màu
984
- - `objects/object_*.png` - Từng đối tượng riêng lẻ (PNG transparent)
985
- - `metadata.json` - Thông tin chi tiết
986
-
987
- **Video Jobs:**
988
- - `overlay_video.mp4` - Video với mask màu
989
- - `masks_only.mp4` - Chỉ mask trắng/đen
990
- - `segmented_video.mp4` - Video với green screen
991
- - `metadata.json` - Thông tin chi tiết
992
-
993
- **Click Jobs:**
994
- - `result.jpg` - Ảnh kết quả
995
- - `metadata.json` - Thông tin chi tiết
996
- """)
997
 
998
- def do_download(job_id):
999
- if not job_id:
1000
- return None, "❌ Vui lòng chọn một job"
1001
-
1002
- zip_path = create_download_package(job_id)
1003
- if zip_path and os.path.exists(zip_path):
1004
- size_mb = os.path.getsize(zip_path) / 1024 / 1024
1005
- return zip_path, f"✅ Sẵn sàng tải về! Kích thước: {size_mb:.2f} MB"
1006
-
1007
- return None, "❌ Không thể tạo package. Job có thể đã bị xóa."
1008
-
1009
- download_refresh.click(
1010
- fn=lambda: gr.Dropdown(choices=get_downloadable_jobs()),
1011
- outputs=[download_dropdown]
1012
  )
1013
 
1014
- download_btn.click(
1015
- fn=do_download,
1016
- inputs=[download_dropdown],
1017
- outputs=[download_file, download_status]
1018
  )
1019
 
1020
- # ===== HISTORY & STATS TAB =====
1021
- with gr.Tab("📊 History & Statistics"):
1022
- with gr.Row():
1023
- with gr.Column(scale=1):
1024
- gr.Markdown("### 📈 Statistics Dashboard")
1025
-
1026
- def update_stats():
1027
- stats = get_history_stats()
1028
- return (
1029
- f"**{stats['total']}**\n\nTổng số jobs",
1030
- f"**{stats['completed']}**\n\nHoàn thành",
1031
- f"**{stats['errors']}**\n\nLỗi",
1032
- f"**{stats['success_rate']}**\n\nTỷ lệ thành công"
1033
- )
1034
-
1035
- with gr.Row():
1036
- stat_total = gr.Markdown("**0**\n\nTổng số jobs", elem_classes=["stat-card"])
1037
- stat_completed = gr.Markdown("**0**\n\nHoàn thành", elem_classes=["stat-card"])
1038
-
1039
- with gr.Row():
1040
- stat_errors = gr.Markdown("**0**\n\nLỗi", elem_classes=["stat-card"])
1041
- stat_success = gr.Markdown("**0%**\n\nTỷ lệ thành công", elem_classes=["stat-card"])
1042
-
1043
- gr.Markdown("### 🎯 Quick Actions")
1044
-
1045
- with gr.Row():
1046
- btn_refresh = gr.Button("🔄 Refresh All", variant="primary")
1047
- btn_export = gr.Button("📥 Export JSON", variant="secondary")
1048
-
1049
- btn_clear_all = gr.Button("🗑️ Clear All History", variant="stop")
1050
-
1051
- export_file = gr.File(label="Exported File", visible=False)
1052
- clear_status = gr.Textbox(label="Status", interactive=False)
1053
-
1054
  with gr.Row():
1055
  with gr.Column():
1056
- gr.Markdown("### 📜 Processing History")
1057
-
1058
- with gr.Row():
1059
- search_input = gr.Textbox(
1060
- placeholder="🔍 Tìm kiếm theo prompt...",
1061
- label="Search",
1062
- scale=2
1063
- )
1064
- filter_type = gr.Dropdown(
1065
- choices=["all", "image", "video", "click"],
1066
- value="all",
1067
- label="Loại",
1068
- scale=1
1069
- )
1070
- filter_status = gr.Dropdown(
1071
- choices=["all", "completed", "error"],
1072
- value="all",
1073
- label="Trạng thái",
1074
- scale=1
1075
- )
1076
 
1077
- history_table = gr.HTML(value=format_history_table())
1078
-
1079
- with gr.Row():
1080
- with gr.Column():
1081
- gr.Markdown("### 🖼️ Gallery - Recent Outputs")
1082
- history_gallery = gr.Gallery(
1083
- value=get_history_gallery(),
1084
- label="Kết quả gần đây",
1085
- columns=4,
1086
- height=400,
1087
- object_fit="contain"
1088
- )
1089
-
1090
- def refresh_all():
1091
- return (
1092
- *update_stats(),
1093
- format_history_table(),
1094
- get_history_gallery()
1095
- )
1096
-
1097
- btn_refresh.click(
1098
- fn=refresh_all,
1099
- outputs=[stat_total, stat_completed, stat_errors, stat_success, history_table, history_gallery]
1100
- )
1101
-
1102
- btn_export.click(
1103
- fn=export_history_json,
1104
- outputs=[export_file]
1105
- )
1106
-
1107
- btn_clear_all.click(
1108
- fn=clear_all_history,
1109
- outputs=[clear_status]
1110
- ).then(
1111
- fn=refresh_all,
1112
- outputs=[stat_total, stat_completed, stat_errors, stat_success, history_table, history_gallery]
1113
- )
1114
-
1115
- def filter_and_display(keyword, ftype, fstatus):
1116
- filtered = search_history(keyword, ftype, fstatus)
1117
- if not filtered:
1118
- return "<p style='text-align:center; color:#666; padding:40px;'>🔍 Không tìm thấy kết quả phù hợp</p>"
1119
- return format_history_table()
1120
-
1121
- search_input.change(
1122
- fn=filter_and_display,
1123
- inputs=[search_input, filter_type, filter_status],
1124
- outputs=[history_table]
1125
- )
1126
-
1127
- filter_type.change(
1128
- fn=filter_and_display,
1129
- inputs=[search_input, filter_type, filter_status],
1130
- outputs=[history_table]
1131
- )
1132
 
1133
- filter_status.change(
1134
- fn=filter_and_display,
1135
- inputs=[search_input, filter_type, filter_status],
1136
- outputs=[history_table]
1137
  )
1138
-
1139
- # Footer
1140
- gr.Markdown("""
1141
- ---
1142
- **SAM3: Segment Anything Model 3** | Powered by DiffusionWave | Background Processing Enabled | No Timeout Limits
1143
- """)
 
 
 
 
1144
 
1145
  if __name__ == "__main__":
1146
- print("🚀 Starting SAM3 Application...")
1147
- print(f"📁 Output directory: {OUTPUTS_DIR}")
1148
- print(f"📥 Downloads directory: {DOWNLOADS_DIR}")
1149
- print(f"📊 History file: {HISTORY_FILE}")
1150
-
1151
  demo.launch(
1152
- server_name="0.0.0.0",
1153
- server_port=7860,
1154
- max_threads=10,
1155
- show_error=True,
1156
- share=False,
1157
  css=custom_css,
1158
- theme=app_theme
 
 
 
1159
  )
 
1
  import os
2
  import cv2
3
+ import tempfile
4
  import spaces
5
  import gradio as gr
6
  import numpy as np
7
  import torch
8
  import matplotlib
9
+ import matplotlib.pyplot as plt
10
  from PIL import Image, ImageDraw
11
  from typing import Iterable
12
  from gradio.themes import Soft
 
21
  import threading
22
  import queue
23
  import uuid
 
 
24
 
25
  # ============ THEME SETUP ============
26
  colors.steel_blue = colors.Color(
27
  name="steel_blue",
28
+ c50="#EBF3F8",
29
+ c100="#D3E5F0",
30
+ c200="#A8CCE1",
31
+ c300="#7DB3D2",
32
+ c400="#529AC3",
33
+ c500="#4682B4",
34
+ c600="#3E72A0",
35
+ c700="#36638C",
36
+ c800="#2E5378",
37
+ c900="#264364",
38
+ c950="#1E3450",
39
  )
40
 
41
  class CustomBlueTheme(Soft):
42
+ def __init__(
43
+ self,
44
+ *,
45
+ primary_hue: colors.Color | str = colors.gray,
46
+ secondary_hue: colors.Color | str = colors.steel_blue,
47
+ neutral_hue: colors.Color | str = colors.slate,
48
+ text_size: sizes.Size | str = sizes.text_lg,
49
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
50
+ fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
51
+ ),
52
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
53
+ fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
54
+ ),
55
+ ):
56
+ super().__init__(
57
+ primary_hue=primary_hue,
58
+ secondary_hue=secondary_hue,
59
+ neutral_hue=neutral_hue,
60
+ text_size=text_size,
61
+ font=font,
62
+ font_mono=font_mono,
63
+ )
64
  super().set(
65
  background_fill_primary="*primary_50",
66
  background_fill_primary_dark="*primary_900",
 
86
  app_theme = CustomBlueTheme()
87
 
88
  # ============ GLOBAL SETUP ============
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+ print(f"🖥️ Using compute device: {device}")
91
 
92
+ # History storage
93
  HISTORY_DIR = "processing_history"
94
+ os.makedirs(HISTORY_DIR, exist_ok=True)
 
 
 
95
  HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
96
 
97
+ # Background processing queue
98
  processing_queue = queue.Queue()
99
  processing_results = {}
100
 
101
  # Load models
102
+ print("⏳ Loading SAM3 Models permanently into memory...")
103
  try:
104
+ print(" ... Loading Image Text Model")
105
  IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3").to(device)
106
  IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
107
+
108
+ print(" ... Loading Image Tracker Model")
109
  TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3").to(device)
110
  TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
111
+
112
+ print(" ... Loading Video Model")
113
+ VID_MODEL = Sam3VideoModel.from_pretrained("DiffusionWave/sam3").to(device, dtype=torch.bfloat16)
114
  VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("DiffusionWave/sam3")
115
 
116
+ print("✅ All Models loaded successfully!")
117
  except Exception as e:
118
+ print(f"❌ CRITICAL ERROR LOADING MODELS: {e}")
119
  IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = VID_MODEL = VID_PROCESSOR = None
120
 
121
  # ============ HISTORY MANAGEMENT ============
122
def load_history():
    """Load the processing history from ``HISTORY_FILE``.

    Returns:
        list: The stored history entries. An empty list is returned when the
        file does not exist, cannot be read, does not contain valid JSON, or
        holds something other than a JSON array.
    """
    try:
        with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
            history = json.load(f)
    except FileNotFoundError:
        # No job has been recorded yet.
        return []
    except (OSError, ValueError) as e:
        # Best effort: an unreadable or corrupt history must not break the app,
        # but the failure is logged instead of being swallowed silently.
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
        print(f"History load error: {e}")
        return []
    # Callers treat the history as a list; guard against a foreign payload.
    return history if isinstance(history, list) else []
131
 
132
def save_history(history_item, max_items=100):
    """Prepend a history entry and persist the history to ``HISTORY_FILE``.

    Args:
        history_item: JSON-serialisable entry describing one job.
        max_items: Maximum number of entries kept; older entries beyond this
            count are dropped. Defaults to 100.
    """
    history = load_history()
    history.insert(0, history_item)  # Newest entry first
    del history[max_items:]          # Keep only the most recent entries

    # Write to a sibling temp file and swap it in, so an interrupted write
    # cannot leave a truncated history file behind.
    tmp_path = f"{HISTORY_FILE}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)
    os.replace(tmp_path, HISTORY_FILE)
 
 
 
139
 
140
def get_history_display():
    """Format the processing history as Markdown text for the history tab.

    Returns:
        str: One paragraph per entry (at most the first 50 entries), or a
        placeholder message when there is no history. Missing fields are
        rendered as ``N/A`` instead of raising ``KeyError``.
    """
    history = load_history()
    if not history:
        return "Chưa có lịch sử xử lý nào"

    parts = []
    for item in history[:50]:
        status_emoji = "✅" if item.get('status') == 'completed' else "❌"
        job_type = str(item.get('type', 'N/A')).upper()
        parts.append(f"{status_emoji} **{job_type}** - {item.get('timestamp', 'N/A')}\n")
        parts.append(f" Prompt: {item.get('prompt', 'N/A')}\n")
        # Only show the file name, not the full storage path.
        if item.get('output_path'):
            parts.append(f" File: `{os.path.basename(item['output_path'])}`\n")
        parts.append("\n")
    return "".join(parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ # ============ UTILITY FUNCTIONS ============
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def apply_mask_overlay(base_image, mask_data, opacity=0.5):
158
+ """Draws segmentation masks on top of an image."""
159
  if isinstance(base_image, np.ndarray):
160
  base_image = Image.fromarray(base_image)
161
  base_image = base_image.convert("RGBA")
 
167
  mask_data = mask_data.cpu().numpy()
168
  mask_data = mask_data.astype(np.uint8)
169
 
170
+ if mask_data.ndim == 4: mask_data = mask_data[0]
171
+ if mask_data.ndim == 3 and mask_data.shape[0] == 1: mask_data = mask_data[0]
 
 
172
 
173
  num_masks = mask_data.shape[0] if mask_data.ndim == 3 else 1
174
  if mask_data.ndim == 2:
 
177
 
178
  try:
179
  color_map = matplotlib.colormaps["rainbow"].resampled(max(num_masks, 1))
180
+ except AttributeError:
181
  import matplotlib.cm as cm
182
  color_map = cm.get_cmap("rainbow").resampled(max(num_masks, 1))
183
 
184
  rgb_colors = [tuple(int(c * 255) for c in color_map(i)[:3]) for i in range(num_masks)]
185
  composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
186
 
187
+ for i, single_mask in enumerate(mask_data):
188
+ mask_bitmap = Image.fromarray((single_mask * 255).astype(np.uint8))
189
+ if mask_bitmap.size != base_image.size:
190
+ mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)
191
 
192
+ fill_color = rgb_colors[i]
193
+ color_fill = Image.new("RGBA", base_image.size, fill_color + (0,))
194
+ mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
195
  color_fill.putalpha(mask_alpha)
196
  composite_layer = Image.alpha_composite(composite_layer, color_fill)
197
 
198
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
199
 
200
def draw_points_on_image(image, points, radius=8, fill="red", outline="white", width=4):
    """Draw a filled dot on a copy of the image for every click location.

    Args:
        image: PIL image or numpy array (arrays are converted with
            ``Image.fromarray``). The input image is not modified.
        points: Iterable of ``(x, y)`` pairs; ``None`` or empty draws nothing.
        radius: Dot radius in pixels. Defaults to 8.
        fill: Dot fill colour. Defaults to red.
        outline: Dot outline colour. Defaults to white.
        width: Outline width in pixels. Defaults to 4.

    Returns:
        A new PIL image with the dots drawn on it.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    draw_img = image.copy()
    draw = ImageDraw.Draw(draw_img)

    for x, y in points or ():
        draw.ellipse(
            (x - radius, y - radius, x + radius, y + radius),
            fill=fill,
            outline=outline,
            width=width,
        )

    return draw_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
+ # ============ BACKGROUND PROCESSING WORKER ============
216
  def background_worker():
217
+ """Background thread that processes jobs from queue"""
218
  while True:
219
  try:
220
  job = processing_queue.get()
 
224
  job_id = job['id']
225
  job_type = job['type']
226
 
 
 
227
  processing_results[job_id] = {'status': 'processing', 'progress': 0}
228
 
229
  try:
 
240
  'progress': 100
241
  }
242
 
243
+ # Save to history
 
244
  save_history({
245
  'id': job_id,
246
  'type': job_type,
247
  'prompt': job.get('prompt', 'N/A'),
248
  'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
249
  'status': 'completed',
250
+ 'output_path': result.get('output_path')
251
  })
252
 
253
  except Exception as e:
 
 
 
 
254
  processing_results[job_id] = {
255
  'status': 'error',
256
  'error': str(e),
 
265
  'error': str(e)
266
  })
267
  except Exception as e:
268
+ print(f"Worker error: {e}")
269
+
270
+ # Start background worker
271
+ worker_thread = threading.Thread(target=background_worker, daemon=True)
272
+ worker_thread.start()
273
+
274
+ # ============ JOB PROCESSORS ============
275
@spaces.GPU
def process_image_job(job):
    """Run text-prompted image segmentation for one job.

    Args:
        job: Dict with ``id``, ``image`` (PIL image or file path), ``prompt``
            and optionally ``conf_thresh`` (defaults to 0.5).

    Returns:
        dict: ``image`` is a ``(pil_image, annotations)`` pair where each
        annotation is ``(mask_array, label)``; ``output_path`` is the saved
        overlay JPEG inside ``HISTORY_DIR``.

    Raises:
        RuntimeError: If the image model failed to load at startup.
    """
    # The loader sets these to None when loading fails; fail with a clear
    # message instead of "'NoneType' object is not callable".
    if IMG_MODEL is None or IMG_PROCESSOR is None:
        raise RuntimeError("SAM3 image model is not loaded")

    source_img = job['image']
    text_query = job['prompt']
    conf_thresh = job.get('conf_thresh', 0.5)

    if isinstance(source_img, str):
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(source_img) as opened:
            pil_image = opened.convert("RGB")
    else:
        pil_image = source_img.convert("RGB")

    model_inputs = IMG_PROCESSOR(images=pil_image, text=text_query, return_tensors="pt").to(device)

    with torch.no_grad():
        inference_output = IMG_MODEL(**model_inputs)

    processed_results = IMG_PROCESSOR.post_process_instance_segmentation(
        inference_output,
        threshold=conf_thresh,
        mask_threshold=0.5,
        target_sizes=model_inputs.get("original_sizes").tolist()
    )[0]

    raw_masks = processed_results['masks'].cpu().numpy()
    raw_scores = processed_results['scores'].cpu().numpy()

    # One (mask, "prompt (score)") pair per detected instance.
    annotation_list = [
        (mask_array, f"{text_query} ({raw_scores[idx]:.2f})")
        for idx, mask_array in enumerate(raw_masks)
    ]

    # Save output
    output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.jpg")
    result_img = apply_mask_overlay(pil_image, raw_masks)
    result_img.save(output_path)

    return {
        'image': (pil_image, annotation_list),
        'output_path': output_path
    }
315
+
316
@spaces.GPU
def process_video_job(job):
    """Run text-prompted video segmentation for one job.

    Args:
        job: Dict with ``id``, ``video`` (file path), ``prompt`` and optionally
            ``frame_limit`` (defaults to 60; a value <= 0 reads every frame).

    Returns:
        dict: ``output_path`` of the rendered overlay video in ``HISTORY_DIR``.

    Raises:
        RuntimeError: If the video model is not loaded or the output file
            cannot be opened for writing.
        ValueError: If the input video cannot be opened or has no frames.
    """
    if VID_MODEL is None or VID_PROCESSOR is None:
        raise RuntimeError("SAM3 video model is not loaded")

    source_vid = job['video']
    text_query = job['prompt']
    frame_limit = job.get('frame_limit', 60)

    video_cap = cv2.VideoCapture(source_vid)
    try:
        if not video_cap.isOpened():
            raise ValueError(f"Cannot open video: {source_vid}")

        vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
        vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        video_frames = []
        while frame_limit <= 0 or len(video_frames) < frame_limit:
            ret, frame = video_cap.read()
            if not ret:
                break
            video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even when reading fails.
        video_cap.release()

    if not video_frames:
        raise ValueError("Video contains no readable frames")
    if not vid_fps > 0:
        # FPS can be reported as 0 or NaN; 30 is an assumed fallback so the
        # writer still gets a usable rate.
        vid_fps = 30.0

    session = VID_PROCESSOR.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
    session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)

    output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.mp4")
    video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
    try:
        if not video_writer.isOpened():
            raise RuntimeError(f"Cannot open video writer for: {output_path}")

        total_frames = len(video_frames)
        frame_iter = VID_MODEL.propagate_in_video_iterator(
            inference_session=session,
            max_frame_num_to_track=total_frames,
        )
        for frame_idx, model_out in enumerate(frame_iter):
            post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
            original_pil = Image.fromarray(video_frames[model_out.frame_idx])

            if 'masks' in post_processed:
                detected_masks = post_processed['masks']
                if detected_masks.ndim == 4:
                    detected_masks = detected_masks.squeeze(1)
                final_frame = apply_mask_overlay(original_pil, detected_masks)
            else:
                final_frame = original_pil

            video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))

            # Update progress; tolerate a job that has no results entry
            # (e.g. when this function is called outside the worker).
            job_state = processing_results.get(job['id'])
            if job_state is not None:
                job_state['progress'] = min(100, int((frame_idx + 1) / total_frames * 100))
    finally:
        # Always finalise the output file, even if inference fails midway.
        video_writer.release()

    return {'output_path': output_path}
364
+
365
@spaces.GPU
def process_click_job(job):
    """Run point-prompted (click) segmentation for one job.

    Args:
        job: Dict with ``id``, ``image`` (PIL image or file path), ``points``
            (list of ``[x, y]`` clicks) and ``labels`` (one label per point).

    Returns:
        dict: ``image`` is the overlay with the click dots drawn on it;
        ``output_path`` is the saved JPEG inside ``HISTORY_DIR``.

    Raises:
        RuntimeError: If the tracker model failed to load at startup.
        ValueError: If the job carries no image.
    """
    if TRK_MODEL is None or TRK_PROCESSOR is None:
        raise RuntimeError("SAM3 tracker model is not loaded")

    input_image = job['image']
    points_state = job['points']
    labels_state = job['labels']

    if input_image is None:
        raise ValueError("No image provided")

    # Normalise to RGB, matching what process_image_job feeds its processor.
    if isinstance(input_image, str):
        with Image.open(input_image) as opened:
            input_image = opened.convert("RGB")
    else:
        input_image = input_image.convert("RGB")

    # Wrap the points/labels in two extra list levels, as in the original call.
    input_points = [[points_state]]
    input_labels = [[labels_state]]

    inputs = TRK_PROCESSOR(
        images=input_image,
        input_points=input_points,
        input_labels=input_labels,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = TRK_MODEL(**inputs, multimask_output=False)

    masks = TRK_PROCESSOR.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
    final_img = apply_mask_overlay(input_image, masks[0])
    final_img = draw_points_on_image(final_img, points_state)

    output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.jpg")
    final_img.save(output_path)

    return {
        'image': final_img,
        'output_path': output_path
    }
394
+
395
+ # ============ UI HANDLERS ============
396
def submit_image_job(source_img, text_query, conf_thresh):
    """Submit an image segmentation job to the background queue.

    Args:
        source_img: Uploaded image, or ``None`` when nothing was uploaded.
        text_query: Text prompt; empty or whitespace-only prompts are rejected.
        conf_thresh: Confidence threshold forwarded to the job.

    Returns:
        tuple: ``(None, status_message, job_id)``; ``job_id`` is an empty
        string when the input is rejected.
    """
    prompt = (text_query or "").strip()
    if source_img is None or not prompt:
        return None, "❌ Vui lòng cung cấp ảnh và prompt", ""

    job_id = str(uuid.uuid4())
    # Register the job before queueing it, so a status check made before the
    # worker picks it up reports "processing" rather than "not found".
    processing_results[job_id] = {'status': 'processing', 'progress': 0}
    processing_queue.put({
        'id': job_id,
        'type': 'image',
        'image': source_img,
        'prompt': prompt,
        'conf_thresh': conf_thresh
    })
    return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
412
 
413
def check_image_status(job_id):
    """Report the state of an image job as ``(result_or_None, status_text)``."""
    if not (job_id and job_id in processing_results):
        return None, "Không tìm thấy công việc"

    entry = processing_results[job_id]
    state = entry['status']

    if state == 'processing':
        return None, f"⏳ Đang xử lý... {entry['progress']}%"
    if state == 'completed':
        return entry['result']['image'], "✅ Hoàn thành!"
    # Any other state is reported as an error.
    return None, f"❌ Lỗi: {entry.get('error', 'Unknown')}"
426
+
427
def submit_video_job(source_vid, text_query, frame_limit, time_limit):
    """Submit a video segmentation job to the background queue.

    Args:
        source_vid: Path of the uploaded video, or empty when nothing was uploaded.
        text_query: Text prompt; empty or whitespace-only prompts are rejected.
        frame_limit: Maximum number of frames, forwarded to the job.
        time_limit: Timeout selection, forwarded to the job as ``time_limit``.

    Returns:
        tuple: ``(None, status_message, job_id)``; ``job_id`` is an empty
        string when the input is rejected.
    """
    prompt = (text_query or "").strip()
    if not source_vid or not prompt:
        return None, "❌ Vui lòng cung cấp video và prompt", ""

    job_id = str(uuid.uuid4())
    # Register the job before queueing it, so a status check made before the
    # worker picks it up reports "processing" rather than "not found".
    processing_results[job_id] = {'status': 'processing', 'progress': 0}
    processing_queue.put({
        'id': job_id,
        'type': 'video',
        'video': source_vid,
        'prompt': prompt,
        'frame_limit': frame_limit,
        'time_limit': time_limit
    })
    return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
444
 
445
def check_video_status(job_id):
    """Report the state of a video job as ``(output_path_or_None, status_text)``."""
    if not (job_id and job_id in processing_results):
        return None, "Không tìm thấy công việc"

    entry = processing_results[job_id]
    state = entry['status']

    if state == 'processing':
        return None, f"⏳ Đang xử lý... {entry['progress']}%"
    if state == 'completed':
        return entry['result']['output_path'], "✅ Hoàn thành!"
    # Any other state is reported as an error.
    return None, f"❌ Lỗi: {entry.get('error', 'Unknown')}"
458
+
459
def image_click_handler(image, evt: gr.SelectData, points_state, labels_state):
    """Record a click and re-run click segmentation with all points so far.

    Returns ``(preview_image, points_state, labels_state)``; when segmentation
    raises, the error is printed and the input image is returned as preview.
    """
    click_x, click_y = evt.index

    points_state = [] if points_state is None else points_state
    labels_state = [] if labels_state is None else labels_state

    # Every click is added with label 1.
    points_state.append([click_x, click_y])
    labels_state.append(1)

    # Process immediately (can be changed to background if needed)
    click_job = {
        'id': str(uuid.uuid4()),
        'type': 'click',
        'image': image,
        'points': points_state,
        'labels': labels_state
    }

    try:
        outcome = process_click_job(click_job)
    except Exception as e:
        print(f"Click error: {e}")
        return image, points_state, labels_state
    return outcome['image'], points_state, labels_state
485
+
486
+ # ============ GRADIO INTERFACE ============
487
# Page-level CSS: centred container, title size, scrollable history box.
custom_css = """
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.1em !important; }
.history-box { max-height: 600px; overflow-y: auto; }
"""

with gr.Blocks(css=custom_css, theme=app_theme) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# **SAM3: Segment Anything Model 3** 🚀", elem_id="main-title")
        gr.Markdown("Xử lý ảnh/video với **background processing** - không cần chờ đợi!")

        with gr.Tabs():
            # ===== IMAGE SEGMENTATION TAB =====
            with gr.Tab("📷 Image Segmentation"):
                with gr.Row():
                    # Integer scales 2:3 keep the former 1:1.5 width ratio.
                    with gr.Column(scale=2):
                        image_input = gr.Image(label="Upload Image", type="pil", height=350)
                        txt_prompt_img = gr.Textbox(label="Text Prompt", placeholder="e.g., cat, face, car wheel")
                        with gr.Accordion("Advanced Settings", open=False):
                            conf_slider = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="Confidence Threshold")

                        btn_submit_img = gr.Button("🚀 Submit Job (Background)", variant="primary")
                        btn_check_img = gr.Button("🔍 Check Status", variant="secondary")
                        # Hidden field carrying the job id from submit to status check.
                        job_id_img = gr.Textbox(label="Job ID", visible=False)

                    with gr.Column(scale=3):
                        image_result = gr.AnnotatedImage(label="Segmented Result", height=410)
                        status_img = gr.Textbox(label="Status", interactive=False)

                btn_submit_img.click(
                    fn=submit_image_job,
                    inputs=[image_input, txt_prompt_img, conf_slider],
                    outputs=[image_result, status_img, job_id_img]
                )

                btn_check_img.click(
                    fn=check_image_status,
                    inputs=[job_id_img],
                    outputs=[image_result, status_img]
                )

            # ===== VIDEO SEGMENTATION TAB =====
            with gr.Tab("🎥 Video Segmentation"):
                with gr.Row():
                    with gr.Column():
                        video_input = gr.Video(label="Upload Video", format="mp4", height=320)
                        txt_prompt_vid = gr.Textbox(label="Text Prompt", placeholder="e.g., person running, red car")

                        with gr.Row():
                            frame_limiter = gr.Slider(10, 500, value=60, step=10, label="Max Frames")
                            # NOTE(review): the selected timeout is stored on the job as
                            # 'time_limit', but process_video_job never reads it.
                            time_limiter = gr.Radio([60, 120, 180], value=60, label="Timeout (seconds)")

                        btn_submit_vid = gr.Button("🚀 Submit Job (Background)", variant="primary")
                        btn_check_vid = gr.Button("🔍 Check Status", variant="secondary")
                        # Hidden field carrying the job id from submit to status check.
                        job_id_vid = gr.Textbox(label="Job ID", visible=False)

                    with gr.Column():
                        video_result = gr.Video(label="Processed Video")
                        status_vid = gr.Textbox(label="Status", interactive=False)

                btn_submit_vid.click(
                    fn=submit_video_job,
                    inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
                    outputs=[video_result, status_vid, job_id_vid]
                )

                btn_check_vid.click(
                    fn=check_video_status,
                    inputs=[job_id_vid],
                    outputs=[video_result, status_vid]
                )

            # ===== CLICK SEGMENTATION TAB =====
            with gr.Tab("👆 Click Segmentation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        img_click_input = gr.Image(type="pil", label="Upload Image", interactive=True, height=450)
                        gr.Markdown("**Hướng dẫn:** Click vào đối tượng bạn muốn phân đoạn")

                        with gr.Row():
                            img_click_clear = gr.Button("🔄 Clear Points & Reset", variant="primary")

                        # Accumulated click coordinates and their labels.
                        st_click_points = gr.State([])
                        st_click_labels = gr.State([])

                    with gr.Column(scale=1):
                        img_click_output = gr.Image(type="pil", label="Result Preview", height=450, interactive=False)

                img_click_input.select(
                    image_click_handler,
                    inputs=[img_click_input, st_click_points, st_click_labels],
                    outputs=[img_click_output, st_click_points, st_click_labels]
                )

                # Reset the preview and both click states.
                img_click_clear.click(
                    lambda: (None, [], []),
                    outputs=[img_click_output, st_click_points, st_click_labels]
                )

            # ===== HISTORY TAB =====
            with gr.Tab("📜 Lịch Sử Xử Lý"):
                with gr.Row():
                    with gr.Column():
                        btn_refresh_history = gr.Button("🔄 Refresh History", variant="primary")
                        history_display = gr.Markdown(value=get_history_display(), elem_classes="history-box")

                with gr.Accordion("Hướng dẫn", open=False):
                    gr.Markdown("""
                    ### Lịch sử lưu:
                    - ✅ **Hoàn thành**: File đã được xử lý thành công
                    - ❌ **Lỗi**: Xử lý thất bại
                    - Tất cả file output được lưu trong thư mục `processing_history/`
                    - Hệ thống giữ lại 100 lịch sử gần nhất
                    """)

                btn_refresh_history.click(
                    fn=get_history_display,
                    outputs=[history_display]
                )

            # ===== BATCH PROCESSING TAB =====
            with gr.Tab("⚙️ Batch Processing"):
                gr.Markdown("### Xử lý hàng loạt (Coming Soon)")
                gr.Markdown("""
                Tính năng này sẽ cho phép bạn:
                - Upload nhiều ảnh/video cùng lúc
                - Tự động xử lý tuần tự
                - Download tất cả kết quả dưới dạng ZIP
                """)
616
 
617
if __name__ == "__main__":
    # Launch settings collected in one place, then passed through unchanged.
    launch_options = {
        "css": custom_css,
        "theme": app_theme,
        "ssr_mode": False,
        "mcp_server": True,
        "show_error": True,
    }
    demo.launch(**launch_options)