Translsis commited on
Commit
7cae504
·
verified ·
1 Parent(s): 25cc101

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +591 -619
app.py CHANGED
@@ -1,12 +1,10 @@
1
  import os
2
  import cv2
3
- import tempfile
4
  import spaces
5
  import gradio as gr
6
  import numpy as np
7
  import torch
8
  import matplotlib
9
- import matplotlib.pyplot as plt
10
  from PIL import Image, ImageDraw
11
  from typing import Iterable
12
  from gradio.themes import Soft
@@ -22,46 +20,23 @@ import threading
22
  import queue
23
  import uuid
24
  import shutil
 
25
 
26
  # ============ THEME SETUP ============
27
  colors.steel_blue = colors.Color(
28
  name="steel_blue",
29
- c50="#EBF3F8",
30
- c100="#D3E5F0",
31
- c200="#A8CCE1",
32
- c300="#7DB3D2",
33
- c400="#529AC3",
34
- c500="#4682B4",
35
- c600="#3E72A0",
36
- c700="#36638C",
37
- c800="#2E5378",
38
- c900="#264364",
39
- c950="#1E3450",
40
  )
41
 
42
  class CustomBlueTheme(Soft):
43
- def __init__(
44
- self,
45
- *,
46
- primary_hue: colors.Color | str = colors.gray,
47
- secondary_hue: colors.Color | str = colors.steel_blue,
48
- neutral_hue: colors.Color | str = colors.slate,
49
- text_size: sizes.Size | str = sizes.text_lg,
50
- font: fonts.Font | str | Iterable[fonts.Font | str] = (
51
- fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
52
- ),
53
- font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
54
- fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
55
- ),
56
- ):
57
- super().__init__(
58
- primary_hue=primary_hue,
59
- secondary_hue=secondary_hue,
60
- neutral_hue=neutral_hue,
61
- text_size=text_size,
62
- font=font,
63
- font_mono=font_mono,
64
- )
65
  super().set(
66
  background_fill_primary="*primary_50",
67
  background_fill_primary_dark="*primary_900",
@@ -88,41 +63,40 @@ app_theme = CustomBlueTheme()
88
 
89
  # ============ GLOBAL SETUP ============
90
  device = "cuda" if torch.cuda.is_available() else "cpu"
91
- print(f"🖥️ Using compute device: {device}")
92
 
93
- # History storage
94
  HISTORY_DIR = "processing_history"
95
  OUTPUTS_DIR = os.path.join(HISTORY_DIR, "outputs")
 
96
  os.makedirs(OUTPUTS_DIR, exist_ok=True)
 
97
  HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
98
 
99
- # Background processing queue
100
  processing_queue = queue.Queue()
101
  processing_results = {}
102
 
103
  # Load models
104
- print("⏳ Loading SAM3 Models permanently into memory...")
105
  try:
106
- print(" ... Loading Image Text Model")
107
  IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3").to(device)
108
  IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
109
-
110
- print(" ... Loading Image Tracker Model")
111
  TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3").to(device)
112
  TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
113
-
114
- print(" ... Loading Video Model")
115
  VID_MODEL = Sam3VideoModel.from_pretrained("DiffusionWave/sam3").to(device, dtype=torch.bfloat16)
116
  VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("DiffusionWave/sam3")
117
 
118
- print("✅ All Models loaded successfully!")
119
  except Exception as e:
120
- print(f"❌ CRITICAL ERROR LOADING MODELS: {e}")
121
  IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = VID_MODEL = VID_PROCESSOR = None
122
 
123
  # ============ HISTORY MANAGEMENT ============
124
  def load_history():
125
- """Load processing history from JSON file"""
126
  if os.path.exists(HISTORY_FILE):
127
  try:
128
  with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
@@ -131,26 +105,22 @@ def load_history():
131
  return []
132
  return []
133
 
134
- def save_history(history_item):
135
- """Save a new history item"""
136
  history = load_history()
137
- history.insert(0, history_item)
138
- history = history[:200] # Keep last 200 items
139
  with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
140
  json.dump(history, f, indent=2, ensure_ascii=False)
141
 
142
  def get_history_stats():
143
- """Get statistics from history"""
144
  history = load_history()
145
  total = len(history)
146
  completed = sum(1 for h in history if h['status'] == 'completed')
147
  errors = sum(1 for h in history if h['status'] == 'error')
148
-
149
  types = {}
150
  for h in history:
151
  t = h['type']
152
  types[t] = types.get(t, 0) + 1
153
-
154
  return {
155
  'total': total,
156
  'completed': completed,
@@ -159,41 +129,90 @@ def get_history_stats():
159
  'types': types
160
  }
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def format_history_table():
163
- """Format history as HTML table"""
164
  history = load_history()
165
  if not history:
166
- return "<p style='text-align:center; color:#666;'>Chưa có lịch sử xử lý nào</p>"
167
 
168
  html = """
169
  <style>
170
- .history-table { width: 100%; border-collapse: collapse; font-size: 14px; }
171
- .history-table th { background: linear-gradient(90deg, #4682B4, #529AC3); color: white; padding: 12px; text-align: left; font-weight: 600; }
172
- .history-table td { padding: 10px; border-bottom: 1px solid #ddd; }
173
- .history-table tr:hover { background-color: #f5f5f5; }
174
- .status-badge { padding: 4px 10px; border-radius: 12px; font-size: 12px; font-weight: 600; }
175
- .status-completed { background: #d4edda; color: #155724; }
176
- .status-error { background: #f8d7da; color: #721c24; }
177
- .status-processing { background: #fff3cd; color: #856404; }
178
- .type-badge { padding: 3px 8px; border-radius: 8px; font-size: 11px; font-weight: 600; background: #e3f2fd; color: #1976d2; }
179
- .action-btn { padding: 5px 12px; margin: 2px; border: none; border-radius: 6px; cursor: pointer; font-size: 12px; font-weight: 600; }
180
- .btn-download { background: #28a745; color: white; }
181
- .btn-delete { background: #dc3545; color: white; }
182
- .btn-download:hover { background: #218838; }
183
- .btn-delete:hover { background: #c82333; }
184
- .prompt-text { max-width: 300px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; }
185
- .file-count { font-size: 11px; color: #666; margin-top: 3px; }
186
  </style>
187
  <table class='history-table'>
188
  <thead>
189
  <tr>
190
- <th style='width: 40px;'>#</th>
191
- <th style='width: 80px;'>Loại</th>
192
- <th style='width: 100px;'>Trạng thái</th>
 
193
  <th>Prompt</th>
194
- <th style='width: 100px;'>Files</th>
195
- <th style='width: 150px;'>Thời gian</th>
196
- <th style='width: 150px;'>Thao tác</th>
197
  </tr>
198
  </thead>
199
  <tbody>
@@ -201,128 +220,87 @@ def format_history_table():
201
 
202
  for i, item in enumerate(history[:100], 1):
203
  status_class = f"status-{item['status']}"
204
- status_text = "✅ Hoàn thành" if item['status'] == 'completed' else "❌ Lỗi" if item['status'] == 'error' else "⏳ Đang xử lý"
205
 
206
  type_icons = {'image': '📷', 'video': '🎥', 'click': '👆'}
207
  type_icon = type_icons.get(item['type'], '📄')
208
 
209
- prompt = item.get('prompt', 'N/A')[:50] + ('...' if len(item.get('prompt', '')) > 50 else '')
 
210
 
211
- # Count files
212
  file_info = []
213
  if item.get('output_path'):
214
- file_info.append("Overlay")
215
  if item.get('segmented_files'):
216
- file_info.append(f"{len(item['segmented_files'])} Objects")
217
  if item.get('mask_video_path'):
218
- file_info.append("Masks")
219
  if item.get('segmented_video_path'):
220
- file_info.append("Segmented")
221
-
222
- files_text = "<br>".join(file_info) if file_info else "N/A"
223
 
224
- download_btn = ""
225
- if item.get('output_path') or item.get('segmented_files'):
226
- download_btn = f"<button class='action-btn btn-download' onclick='downloadFiles(\"{item['id']}\")'>📥 Download</button>"
227
-
228
- delete_btn = f"<button class='action-btn btn-delete' onclick='deleteHistory(\"{item['id']}\")'>🗑️ Xóa</button>"
229
 
230
  html += f"""
231
  <tr>
232
- <td>{i}</td>
 
233
  <td><span class='type-badge'>{type_icon} {item['type'].upper()}</span></td>
234
  <td><span class='status-badge {status_class}'>{status_text}</span></td>
235
- <td class='prompt-text' title='{item.get("prompt", "N/A")}'>{prompt}</td>
236
  <td><div class='file-count'>{files_text}</div></td>
237
- <td>{item['timestamp']}<br><small>{item.get('duration', '')}</small></td>
238
- <td>{download_btn}{delete_btn}</td>
 
 
239
  </tr>
240
  """
241
 
242
  html += """
243
  </tbody>
244
  </table>
245
- <script>
246
- function downloadFiles(id) {
247
- alert('Download functionality: ' + id + '\\nFiles will be packaged as ZIP');
248
- }
249
- function deleteHistory(id) {
250
- if(confirm('Bạn có chắc muốn xóa mục này?')) {
251
- alert('Deleted: ' + id);
252
- }
253
- }
254
- </script>
255
  """
256
 
257
  return html
258
 
259
  def get_history_gallery():
260
- """Get recent outputs for gallery display"""
261
  history = load_history()
262
  gallery_items = []
263
 
264
- for item in history[:20]:
265
- if item['status'] == 'completed' and item.get('output_path'):
266
- output_path = item['output_path']
267
- if os.path.exists(output_path):
268
- caption = f"{item['type'].upper()} | {item['prompt'][:30]}... | {item['timestamp']}"
269
- gallery_items.append((output_path, caption))
270
 
271
- return gallery_items
272
 
273
  def search_history(keyword, filter_type, filter_status):
274
- """Search and filter history"""
275
  history = load_history()
276
  filtered = history
277
 
278
  if keyword:
279
  filtered = [h for h in filtered if keyword.lower() in h.get('prompt', '').lower()]
280
-
281
  if filter_type and filter_type != "all":
282
  filtered = [h for h in filtered if h['type'] == filter_type]
283
-
284
  if filter_status and filter_status != "all":
285
  filtered = [h for h in filtered if h['status'] == filter_status]
286
 
287
  return filtered
288
 
289
- def delete_history_item(item_id):
290
- """Delete a history item and its output file"""
291
- history = load_history()
292
- updated_history = []
293
- deleted = False
294
-
295
- for item in history:
296
- if item['id'] == item_id:
297
- # Delete output file if exists
298
- if item.get('output_path') and os.path.exists(item['output_path']):
299
- try:
300
- os.remove(item['output_path'])
301
- except:
302
- pass
303
- deleted = True
304
- else:
305
- updated_history.append(item)
306
-
307
- if deleted:
308
- with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
309
- json.dump(updated_history, f, indent=2, ensure_ascii=False)
310
- return "✅ Đã xóa thành công"
311
- return "❌ Không tìm thấy mục cần xóa"
312
-
313
  def clear_all_history():
314
- """Clear all history and output files"""
315
  if os.path.exists(OUTPUTS_DIR):
316
  shutil.rmtree(OUTPUTS_DIR)
317
  os.makedirs(OUTPUTS_DIR)
 
 
 
318
 
319
  with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
320
  json.dump([], f)
321
 
322
- return "✅ Đã xóa toàn bộ lịch sử"
323
 
324
  def export_history_json():
325
- """Export history as downloadable JSON"""
326
  history = load_history()
327
  export_path = os.path.join(HISTORY_DIR, f"history_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
328
 
@@ -331,9 +309,8 @@ def export_history_json():
331
 
332
  return export_path
333
 
334
- # ============ UTILITY FUNCTIONS ============
335
  def apply_mask_overlay(base_image, mask_data, opacity=0.5):
336
- """Draws segmentation masks on top of an image."""
337
  if isinstance(base_image, np.ndarray):
338
  base_image = Image.fromarray(base_image)
339
  base_image = base_image.convert("RGBA")
@@ -345,8 +322,10 @@ def apply_mask_overlay(base_image, mask_data, opacity=0.5):
345
  mask_data = mask_data.cpu().numpy()
346
  mask_data = mask_data.astype(np.uint8)
347
 
348
- if mask_data.ndim == 4: mask_data = mask_data[0]
349
- if mask_data.ndim == 3 and mask_data.shape[0] == 1: mask_data = mask_data[0]
 
 
350
 
351
  num_masks = mask_data.shape[0] if mask_data.ndim == 3 else 1
352
  if mask_data.ndim == 2:
@@ -355,610 +334,603 @@ def apply_mask_overlay(base_image, mask_data, opacity=0.5):
355
 
356
  try:
357
  color_map = matplotlib.colormaps["rainbow"].resampled(max(num_masks, 1))
358
- except AttributeError:
359
  import matplotlib.cm as cm
360
  color_map = cm.get_cmap("rainbow").resampled(max(num_masks, 1))
361
 
362
  rgb_colors = [tuple(int(c * 255) for c in color_map(i)[:3]) for i in range(num_masks)]
363
  composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
364
 
365
- for i, single_mask in enumerate(mask_data):
366
- mask_bitmap = Image.fromarray((single_mask * 255).astype(np.uint8))
367
- if mask_bitmap.size != base_image.size:
368
- mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)
369
 
370
- fill_color = rgb_colors[i]
371
- color_fill = Image.new("RGBA", base_image.size, fill_color + (0,))
372
- mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
373
  color_fill.putalpha(mask_alpha)
374
  composite_layer = Image.alpha_composite(composite_layer, color_fill)
375
 
376
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
377
 
378
  def draw_points_on_image(image, points):
379
- """Draws red dots on the image to indicate click locations."""
380
  if isinstance(image, np.ndarray):
381
  image = Image.fromarray(image)
382
-
383
  draw_img = image.copy()
384
  draw = ImageDraw.Draw(draw_img)
385
-
386
- for pt in points:
387
- x, y = pt
388
  r = 8
389
  draw.ellipse((x-r, y-r, x+r, y+r), fill="red", outline="white", width=4)
390
-
391
  return draw_img
392
 
393
- # ============ BACKGROUND PROCESSING WORKER ============
394
- def background_worker():
395
- """Background thread that processes jobs from queue"""
396
- while True:
397
- try:
398
- job = processing_queue.get()
399
- if job is None:
400
- break
401
-
402
- job_id = job['id']
403
- job_type = job['type']
404
-
405
- processing_results[job_id] = {'status': 'processing', 'progress': 0}
406
-
407
- try:
408
- if job_type == 'image':
409
- result = process_image_job(job)
410
- elif job_type == 'video':
411
- result = process_video_job(job)
412
- elif job_type == 'click':
413
- result = process_click_job(job)
414
-
415
- processing_results[job_id] = {
416
- 'status': 'completed',
417
- 'result': result,
418
- 'progress': 100
419
- }
420
-
421
- save_history({
422
- 'id': job_id,
423
- 'type': job_type,
424
- 'prompt': job.get('prompt', 'N/A'),
425
- 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
426
- 'status': 'completed',
427
- 'output_path': result.get('output_path'),
428
- 'segmented_files': result.get('segmented_files', []),
429
- 'mask_video_path': result.get('mask_video_path'),
430
- 'segmented_video_path': result.get('segmented_video_path'),
431
- 'num_objects': result.get('num_objects', 0),
432
- 'duration': result.get('duration', 'N/A')
433
- })
434
-
435
- except Exception as e:
436
- processing_results[job_id] = {
437
- 'status': 'error',
438
- 'error': str(e),
439
- 'progress': 0
440
- }
441
- save_history({
442
- 'id': job_id,
443
- 'type': job_type,
444
- 'prompt': job.get('prompt', 'N/A'),
445
- 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
446
- 'status': 'error',
447
- 'error': str(e)
448
- })
449
- except Exception as e:
450
- print(f"Worker error: {e}")
451
-
452
- worker_thread = threading.Thread(target=background_worker, daemon=True)
453
- worker_thread.start()
454
-
455
  # ============ JOB PROCESSORS ============
456
  @spaces.GPU
457
  def process_image_job(job):
458
- """Process image segmentation job"""
459
- start_time = datetime.now()
460
- source_img = job['image']
461
- text_query = job['prompt']
462
- conf_thresh = job.get('conf_thresh', 0.5)
463
 
464
- if isinstance(source_img, str):
465
- source_img = Image.open(source_img)
466
 
467
- pil_image = source_img.convert("RGB")
468
- model_inputs = IMG_PROCESSOR(images=pil_image, text=text_query, return_tensors="pt").to(device)
469
-
470
  with torch.no_grad():
471
- inference_output = IMG_MODEL(**model_inputs)
472
-
473
- processed_results = IMG_PROCESSOR.post_process_instance_segmentation(
474
- inference_output,
475
- threshold=conf_thresh,
476
  mask_threshold=0.5,
477
- target_sizes=model_inputs.get("original_sizes").tolist()
478
  )[0]
479
-
480
- annotation_list = []
481
- raw_masks = processed_results['masks'].cpu().numpy()
482
- raw_scores = processed_results['scores'].cpu().numpy()
483
 
484
- for idx, mask_array in enumerate(raw_masks):
485
- label_str = f"{text_query} ({raw_scores[idx]:.2f})"
486
- annotation_list.append((mask_array, label_str))
487
 
488
- # Save overlay result
489
- output_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.jpg")
490
- result_img = apply_mask_overlay(pil_image, raw_masks)
491
- result_img.save(output_path)
492
 
493
- # Extract and save individual segmented objects
494
- segmented_files = []
495
- for idx, mask_array in enumerate(raw_masks):
496
- # Create transparent background for segmented object
497
- mask_bool = mask_array.astype(bool)
498
-
499
- # Create RGBA image
500
- segmented = Image.new("RGBA", pil_image.size, (0, 0, 0, 0))
501
- img_array = np.array(pil_image.convert("RGBA"))
502
-
503
- # Apply mask
504
- img_array[~mask_bool] = [0, 0, 0, 0]
505
- segmented = Image.fromarray(img_array)
506
-
507
- # Crop to bounding box to save space
508
- bbox = Image.fromarray(mask_array * 255).getbbox()
509
  if bbox:
510
- segmented_cropped = segmented.crop(bbox)
511
- seg_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_object_{idx+1}.png")
512
- segmented_cropped.save(seg_path)
513
- segmented_files.append(seg_path)
514
-
515
- duration = (datetime.now() - start_time).total_seconds()
516
 
517
  return {
518
- 'image': (pil_image, annotation_list),
519
- 'output_path': output_path,
520
- 'segmented_files': segmented_files,
521
- 'num_objects': len(segmented_files),
522
- 'duration': f"{duration:.2f}s"
523
  }
524
 
525
  @spaces.GPU
526
  def process_video_job(job):
527
- """Process video segmentation job"""
528
- start_time = datetime.now()
529
- source_vid = job['video']
530
- text_query = job['prompt']
531
- frame_limit = job.get('frame_limit', 60)
532
 
533
- video_cap = cv2.VideoCapture(source_vid)
534
- vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
535
- vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
536
- vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
 
 
 
 
 
537
 
538
- video_frames = []
539
- counter = 0
540
- while video_cap.isOpened():
541
- ret, frame = video_cap.read()
542
- if not ret or (frame_limit > 0 and counter >= frame_limit): break
543
- video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
544
- counter += 1
545
- video_cap.release()
546
 
547
- session = VID_PROCESSOR.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
548
- session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)
 
549
 
550
- # Overlay video
551
- output_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.mp4")
552
- video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
 
 
553
 
554
- # Mask-only video (black background with white masks)
555
- mask_video_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_masks_only.mp4")
556
- mask_writer = cv2.VideoWriter(mask_video_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
557
-
558
- # Segmented objects video (transparent background)
559
- segmented_video_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_segmented.mp4")
560
- segmented_writer = cv2.VideoWriter(segmented_video_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
561
-
562
- total_frames = len(video_frames)
563
- for frame_idx, model_out in enumerate(VID_MODEL.propagate_in_video_iterator(inference_session=session, max_frame_num_to_track=total_frames)):
564
- post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
565
- f_idx = model_out.frame_idx
566
- original_pil = Image.fromarray(video_frames[f_idx])
567
 
568
- if 'masks' in post_processed:
569
- detected_masks = post_processed['masks']
570
- if detected_masks.ndim == 4: detected_masks = detected_masks.squeeze(1)
571
-
572
- # 1. Overlay frame
573
- overlay_frame = apply_mask_overlay(original_pil, detected_masks)
574
- video_writer.write(cv2.cvtColor(np.array(overlay_frame), cv2.COLOR_RGB2BGR))
575
-
576
- # 2. Mask-only frame (white masks on black background)
577
- mask_frame = np.zeros((vid_h, vid_w, 3), dtype=np.uint8)
578
- if isinstance(detected_masks, torch.Tensor):
579
- detected_masks_np = detected_masks.cpu().numpy()
580
- else:
581
- detected_masks_np = detected_masks
582
-
583
- # Combine all masks
584
- combined_mask = np.zeros((vid_h, vid_w), dtype=np.uint8)
585
- for mask in detected_masks_np:
586
- if mask.shape != (vid_h, vid_w):
587
- mask = cv2.resize(mask.astype(np.uint8), (vid_w, vid_h), interpolation=cv2.INTER_NEAREST)
588
- combined_mask = np.maximum(combined_mask, mask)
589
 
590
- mask_frame[combined_mask > 0] = [255, 255, 255]
591
- mask_writer.write(mask_frame)
592
 
593
- # 3. Segmented frame (original with background removed)
594
- segmented_frame = np.array(original_pil.convert("RGBA"))
595
- alpha_mask = (combined_mask * 255).astype(np.uint8)
596
- segmented_frame[:, :, 3] = alpha_mask
 
 
597
 
598
- # Convert to BGR for video (with green screen for transparency)
599
- bgr_frame = np.zeros((vid_h, vid_w, 3), dtype=np.uint8)
600
- bgr_frame[:, :] = [0, 255, 0] # Green background
601
 
 
 
 
 
602
  for c in range(3):
603
- bgr_frame[:, :, c] = np.where(
604
- combined_mask > 0,
605
- segmented_frame[:, :, 2-c], # RGB to BGR
606
- bgr_frame[:, :, c]
607
- )
608
-
609
- segmented_writer.write(bgr_frame)
610
- else:
611
- # No masks detected, write original frames
612
- video_writer.write(cv2.cvtColor(np.array(original_pil), cv2.COLOR_RGB2BGR))
613
- mask_writer.write(np.zeros((vid_h, vid_w, 3), dtype=np.uint8))
614
- segmented_writer.write(cv2.cvtColor(np.array(original_pil), cv2.COLOR_RGB2BGR))
615
-
616
- progress = int((frame_idx + 1) / total_frames * 100)
617
- processing_results[job['id']]['progress'] = progress
618
 
619
- video_writer.release()
620
- mask_writer.release()
621
- segmented_writer.release()
622
 
623
- duration = (datetime.now() - start_time).total_seconds()
 
624
 
625
  return {
626
- 'output_path': output_path,
627
- 'mask_video_path': mask_video_path,
628
- 'segmented_video_path': segmented_video_path,
629
- 'duration': f"{duration:.2f}s"
630
  }
631
 
632
  @spaces.GPU
633
  def process_click_job(job):
634
- """Process click segmentation job"""
635
- start_time = datetime.now()
636
- input_image = job['image']
637
- points_state = job['points']
638
- labels_state = job['labels']
639
-
640
- if isinstance(input_image, str):
641
- input_image = Image.open(input_image)
642
 
643
- input_points = [[points_state]]
644
- input_labels = [[labels_state]]
645
-
646
- inputs = TRK_PROCESSOR(images=input_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
 
 
647
 
648
  with torch.no_grad():
649
  outputs = TRK_MODEL(**inputs, multimask_output=False)
650
-
651
- masks = TRK_PROCESSOR.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
652
- final_img = apply_mask_overlay(input_image, masks[0])
653
- final_img = draw_points_on_image(final_img, points_state)
654
-
655
- output_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_result.jpg")
656
- final_img.save(output_path)
657
-
658
- duration = (datetime.now() - start_time).total_seconds()
659
-
660
- return {
661
- 'image': final_img,
662
- 'output_path': output_path,
663
- 'duration': f"{duration:.2f}s"
664
- }
665
-
666
- # ============ UI HANDLERS ============
667
- def submit_image_job(source_img, text_query, conf_thresh):
668
- if source_img is None or not text_query:
669
- return None, "❌ Vui lòng cung cấp ảnh và prompt", ""
670
-
671
- job_id = str(uuid.uuid4())
672
- job = {
673
- 'id': job_id,
674
- 'type': 'image',
675
- 'image': source_img,
676
- 'prompt': text_query,
677
- 'conf_thresh': conf_thresh
678
- }
679
 
680
- processing_queue.put(job)
681
- return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
682
-
683
- def check_image_status(job_id):
684
- if not job_id or job_id not in processing_results:
685
- return None, "Không tìm thấy công việc"
686
 
687
- result = processing_results[job_id]
 
688
 
689
- if result['status'] == 'processing':
690
- return None, f"⏳ Đang xử lý... {result['progress']}%"
691
- elif result['status'] == 'completed':
692
- return result['result']['image'], "✅ Hoàn thành!"
693
- else:
694
- return None, f"❌ Lỗi: {result.get('error', 'Unknown')}"
695
-
696
- def submit_video_job(source_vid, text_query, frame_limit, time_limit):
697
- if not source_vid or not text_query:
698
- return None, "❌ Vui lòng cung cấp video và prompt", ""
699
 
700
- job_id = str(uuid.uuid4())
701
- job = {
702
- 'id': job_id,
703
- 'type': 'video',
704
- 'video': source_vid,
705
- 'prompt': text_query,
706
- 'frame_limit': frame_limit,
707
- 'time_limit': time_limit
708
  }
709
-
710
- processing_queue.put(job)
711
- return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
712
 
713
- def check_video_status(job_id):
714
- if not job_id or job_id not in processing_results:
715
- return None, "Không tìm thấy công việc"
716
-
717
- result = processing_results[job_id]
718
-
719
- if result['status'] == 'processing':
720
- return None, f"⏳ Đang xử lý... {result['progress']}%"
721
- elif result['status'] == 'completed':
722
- return result['result']['output_path'], "✅ Hoàn thành!"
723
- else:
724
- return None, f"❌ Lỗi: {result.get('error', 'Unknown')}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
725
 
726
- def image_click_handler(image, evt: gr.SelectData, points_state, labels_state):
727
- x, y = evt.index
728
-
729
- if points_state is None: points_state = []
730
- if labels_state is None: labels_state = []
731
-
732
- points_state.append([x, y])
733
- labels_state.append(1)
734
-
735
- job_id = str(uuid.uuid4())
736
- job = {
737
- 'id': job_id,
738
- 'type': 'click',
739
- 'image': image,
740
- 'points': points_state,
741
- 'labels': labels_state
742
- }
743
-
744
- try:
745
- result = process_click_job(job)
746
- return result['image'], points_state, labels_state
747
- except Exception as e:
748
- print(f"Click error: {e}")
749
- return image, points_state, labels_state
750
 
751
- # ============ GRADIO INTERFACE ============
752
- custom_css="""
753
- #col-container { margin: 0 auto; max-width: 1300px; }
754
- #main-title h1 { font-size: 2.1em !important; }
755
- .stat-card { padding: 20px; border-radius: 12px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; text-align: center; }
756
- .stat-number { font-size: 2.5em; font-weight: 700; margin: 10px 0; }
757
- .stat-label { font-size: 1.1em; opacity: 0.9; }
758
  """
759
 
760
- with gr.Blocks(css=custom_css, theme=app_theme) as demo:
761
  with gr.Column(elem_id="col-container"):
762
  gr.Markdown("# **SAM3: Segment Anything Model 3** 🚀", elem_id="main-title")
763
- gr.Markdown("Xử lý ảnh/video với **background processing** - không cần chờ đợi!")
764
-
765
  with gr.Tabs():
766
- # ===== IMAGE SEGMENTATION TAB =====
767
  with gr.Tab("📷 Image Segmentation"):
768
  with gr.Row():
769
  with gr.Column(scale=1):
770
- image_input = gr.Image(label="Upload Image", type="pil", height=350)
771
- txt_prompt_img = gr.Textbox(label="Text Prompt", placeholder="e.g., cat, face, car wheel")
772
- with gr.Accordion("Advanced Settings", open=False):
773
- conf_slider = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="Confidence Threshold")
 
 
 
 
774
 
775
- btn_submit_img = gr.Button("🚀 Submit Job (Background)", variant="primary")
776
- btn_check_img = gr.Button("🔍 Check Status", variant="secondary")
777
- job_id_img = gr.Textbox(label="Job ID", visible=False)
778
-
779
  with gr.Column(scale=1.5):
780
- image_result = gr.AnnotatedImage(label="Segmented Result (Overlay)", height=410)
781
- status_img = gr.Textbox(label="Status", interactive=False)
782
 
783
  with gr.Accordion("📦 Extracted Objects", open=True):
784
- gr.Markdown("**Các đối tượng được tách ra sẽ hiển thị đây:**")
785
- segmented_gallery = gr.Gallery(
786
- label="Segmented Objects (PNG with transparent background)",
787
  columns=3,
788
  height=300,
789
  object_fit="contain"
790
  )
791
 
792
- def check_and_display_image(job_id):
793
- """Check status and display both overlay and segmented objects"""
794
- if not job_id or job_id not in processing_results:
795
- return None, "Không tìm thấy công việc", []
 
 
 
 
 
 
 
 
 
 
 
 
796
 
797
- result = processing_results[job_id]
798
 
799
- if result['status'] == 'processing':
800
- return None, f"⏳ Đang xử lý... {result['progress']}%", []
801
- elif result['status'] == 'completed':
802
- job_result = result['result']
803
- segmented_files = job_result.get('segmented_files', [])
804
-
805
- # Create gallery items
806
- gallery_items = []
807
- for i, seg_file in enumerate(segmented_files, 1):
808
- if os.path.exists(seg_file):
809
- gallery_items.append(seg_file)
810
-
811
- status_msg = f"✅ Hoàn thành! Đã tách được {len(gallery_items)} đối tượng"
812
- return job_result['image'], status_msg, gallery_items
813
  else:
814
- return None, f"❌ Lỗi: {result.get('error', 'Unknown')}", []
815
 
816
- btn_submit_img.click(
817
- fn=submit_image_job,
818
- inputs=[image_input, txt_prompt_img, conf_slider],
819
- outputs=[image_result, status_img, job_id_img]
820
  )
821
 
822
- btn_check_img.click(
823
- fn=check_and_display_image,
824
- inputs=[job_id_img],
825
- outputs=[image_result, status_img, segmented_gallery]
826
  )
827
 
828
- # ===== VIDEO SEGMENTATION TAB =====
829
  with gr.Tab("🎥 Video Segmentation"):
830
  with gr.Row():
831
  with gr.Column():
832
- video_input = gr.Video(label="Upload Video", format="mp4", height=320)
833
- txt_prompt_vid = gr.Textbox(label="Text Prompt", placeholder="e.g., person running, red car")
834
-
835
- with gr.Row():
836
- frame_limiter = gr.Slider(10, 500, value=60, step=10, label="Max Frames")
837
- time_limiter = gr.Radio([60, 120, 180], value=60, label="Timeout (seconds)")
838
 
839
- btn_submit_vid = gr.Button("🚀 Submit Job (Background)", variant="primary")
840
- btn_check_vid = gr.Button("🔍 Check Status", variant="secondary")
841
- job_id_vid = gr.Textbox(label="Job ID", visible=False)
 
 
 
842
 
 
 
 
 
843
  with gr.Column():
844
- gr.Markdown("### 📹 Video Outputs")
845
 
846
  with gr.Tabs():
847
- with gr.Tab("Overlay"):
848
- video_result_overlay = gr.Video(label="1. Overlay (Original + Masks)")
 
849
 
850
- with gr.Tab("Masks Only"):
851
- video_result_masks = gr.Video(label="2. Masks Only (White on Black)")
 
852
 
853
- with gr.Tab("Segmented"):
854
- video_result_segmented = gr.Video(label="3. Segmented (Green Screen Background)")
 
855
 
856
- status_vid = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
857
 
858
- def check_and_display_video(job_id):
859
- """Check status and display all video outputs"""
860
- if not job_id or job_id not in processing_results:
861
- return None, None, None, "Không tìm thấy công việc"
862
 
863
- result = processing_results[job_id]
864
 
865
- if result['status'] == 'processing':
866
- status = f"⏳ Đang xử lý... {result['progress']}%"
867
- return None, None, None, status
868
- elif result['status'] == 'completed':
869
- job_result = result['result']
870
- overlay = job_result.get('output_path')
871
- masks = job_result.get('mask_video_path')
872
- segmented = job_result.get('segmented_video_path')
873
-
874
- status = "✅ Hoàn thành! 3 video đã được tạo:\n"
875
- status += "1️⃣ Overlay - Ảnh gốc với mask màu\n"
876
- status += "2️⃣ Masks Only - Chỉ mask (trắng/đen)\n"
877
- status += "3️⃣ Segmented - Đối tượng với green screen"
878
 
879
- return overlay, masks, segmented, status
 
 
 
 
 
 
 
 
 
880
  else:
881
- error_msg = f"❌ Lỗi: {result.get('error', 'Unknown')}"
882
- return None, None, None, error_msg
883
-
884
- btn_submit_vid.click(
885
- fn=submit_video_job,
886
- inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
887
- outputs=[video_result_overlay, status_vid, job_id_vid]
888
- )
889
 
890
- btn_check_vid.click(
891
- fn=check_and_display_video,
892
- inputs=[job_id_vid],
893
- outputs=[video_result_overlay, video_result_masks, video_result_segmented, status_vid]
894
  )
895
 
896
- # ===== CLICK SEGMENTATION TAB =====
 
 
 
 
 
 
897
  with gr.Tab("👆 Click Segmentation"):
898
  with gr.Row():
899
  with gr.Column(scale=1):
900
- img_click_input = gr.Image(type="pil", label="Upload Image", interactive=True, height=450)
901
- gr.Markdown("**Hướng dẫn:** Click vào đối tượng bạn muốn phân đoạn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
902
 
903
  with gr.Row():
904
- img_click_clear = gr.Button("🔄 Clear Points & Reset", variant="primary")
 
905
 
906
- st_click_points = gr.State([])
907
- st_click_labels = gr.State([])
908
-
909
  with gr.Column(scale=1):
910
- img_click_output = gr.Image(type="pil", label="Result Preview", height=450, interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911
 
912
- img_click_input.select(
913
- image_click_handler,
914
- inputs=[img_click_input, st_click_points, st_click_labels],
915
- outputs=[img_click_output, st_click_points, st_click_labels]
916
  )
917
 
918
- img_click_clear.click(
919
- lambda: (None, [], []),
920
- outputs=[img_click_output, st_click_points, st_click_labels]
 
921
  )
922
 
923
- # ===== ADVANCED HISTORY TAB =====
924
- with gr.Tab("📊 Lịch Sử & Thống Kê"):
925
  with gr.Row():
926
- # Statistics Dashboard
927
  with gr.Column(scale=1):
928
- gr.Markdown("### 📈 Thống Kê Tổng Quan")
929
 
930
  def update_stats():
931
  stats = get_history_stats()
932
  return (
933
- f"**{stats['total']}** Tổng số",
934
- f"**{stats['completed']}** Hoàn thành",
935
- f"**{stats['errors']}** Lỗi",
936
- f"**{stats['success_rate']}** Tỷ lệ thành công"
937
  )
938
 
939
  with gr.Row():
940
- stat_total = gr.Markdown("**0** Tổng số")
941
- stat_completed = gr.Markdown("**0** Hoàn thành")
942
- with gr.Row():
943
- stat_errors = gr.Markdown("**0** Lỗi")
944
- stat_success = gr.Markdown("**0%** Tỷ lệ thành công")
945
 
946
- gr.Markdown("### 🎯 Thao Tác Nhanh")
947
  with gr.Row():
948
- btn_refresh = gr.Button("🔄 Refresh", variant="primary", scale=1)
949
- btn_export = gr.Button("📥 Export JSON", variant="secondary", scale=1)
 
 
 
950
  with gr.Row():
951
- btn_clear_all = gr.Button("🗑️ Clear All History", variant="stop", scale=1)
 
 
 
952
 
953
  export_file = gr.File(label="Exported File", visible=False)
954
  clear_status = gr.Textbox(label="Status", interactive=False)
955
 
956
- # History Table
957
  with gr.Row():
958
  with gr.Column():
959
- gr.Markdown("### 📜 Lịch Sử Chi Tiết")
960
 
961
- # Search and Filter
962
  with gr.Row():
963
  search_input = gr.Textbox(
964
  placeholder="🔍 Tìm kiếm theo prompt...",
@@ -980,19 +952,17 @@ with gr.Blocks(css=custom_css, theme=app_theme) as demo:
980
 
981
  history_table = gr.HTML(value=format_history_table())
982
 
983
- # Gallery View
984
  with gr.Row():
985
  with gr.Column():
986
- gr.Markdown("### 🖼️ Gallery - Kết Quả Gần Đây")
987
  history_gallery = gr.Gallery(
988
  value=get_history_gallery(),
989
- label="Recent Outputs",
990
  columns=4,
991
  height=400,
992
  object_fit="contain"
993
  )
994
 
995
- # Event handlers
996
  def refresh_all():
997
  return (
998
  *update_stats(),
@@ -1018,38 +988,40 @@ with gr.Blocks(css=custom_css, theme=app_theme) as demo:
1018
  outputs=[stat_total, stat_completed, stat_errors, stat_success, history_table, history_gallery]
1019
  )
1020
 
1021
- # Auto-refresh when searching/filtering
1022
  def filter_and_display(keyword, ftype, fstatus):
1023
  filtered = search_history(keyword, ftype, fstatus)
1024
- # Format filtered results
1025
  if not filtered:
1026
- return "<p style='text-align:center; color:#666;'>Không tìm thấy kết quả</p>"
1027
-
1028
- # Reuse formatting logic
1029
- html = format_history_table()
1030
- return html
1031
 
1032
  search_input.change(
1033
  fn=filter_and_display,
1034
  inputs=[search_input, filter_type, filter_status],
1035
  outputs=[history_table]
1036
  )
 
1037
  filter_type.change(
1038
  fn=filter_and_display,
1039
  inputs=[search_input, filter_type, filter_status],
1040
  outputs=[history_table]
1041
  )
 
1042
  filter_status.change(
1043
  fn=filter_and_display,
1044
  inputs=[search_input, filter_type, filter_status],
1045
  outputs=[history_table]
1046
  )
 
 
 
 
 
 
1047
 
1048
  if __name__ == "__main__":
1049
  demo.launch(
1050
  css=custom_css,
1051
  theme=app_theme,
1052
  ssr_mode=False,
1053
- mcp_server=True,
1054
  show_error=True
1055
  )
 
1
  import os
2
  import cv2
 
3
  import spaces
4
  import gradio as gr
5
  import numpy as np
6
  import torch
7
  import matplotlib
 
8
  from PIL import Image, ImageDraw
9
  from typing import Iterable
10
  from gradio.themes import Soft
 
20
  import queue
21
  import uuid
22
  import shutil
23
+ import zipfile
24
 
25
  # ============ THEME SETUP ============
26
  colors.steel_blue = colors.Color(
27
  name="steel_blue",
28
+ c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2",
29
+ c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C",
30
+ c800="#2E5378", c900="#264364", c950="#1E3450"
 
 
 
 
 
 
 
 
31
  )
32
 
33
  class CustomBlueTheme(Soft):
34
+ def __init__(self, *, primary_hue=colors.gray, secondary_hue=colors.steel_blue,
35
+ neutral_hue=colors.slate, text_size=sizes.text_lg,
36
+ font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
37
+ font_mono=(fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace")):
38
+ super().__init__(primary_hue=primary_hue, secondary_hue=secondary_hue,
39
+ neutral_hue=neutral_hue, text_size=text_size, font=font, font_mono=font_mono)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  super().set(
41
  background_fill_primary="*primary_50",
42
  background_fill_primary_dark="*primary_900",
 
63
 
64
  # ============ GLOBAL SETUP ============
65
  device = "cuda" if torch.cuda.is_available() else "cpu"
66
+ print(f"🖥️ Using device: {device}")
67
 
 
68
  HISTORY_DIR = "processing_history"
69
  OUTPUTS_DIR = os.path.join(HISTORY_DIR, "outputs")
70
+ DOWNLOADS_DIR = os.path.join(HISTORY_DIR, "downloads")
71
  os.makedirs(OUTPUTS_DIR, exist_ok=True)
72
+ os.makedirs(DOWNLOADS_DIR, exist_ok=True)
73
  HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
74
 
 
75
  processing_queue = queue.Queue()
76
  processing_results = {}
77
 
78
  # Load models
79
+ print("⏳ Loading SAM3 Models...")
80
  try:
81
+ print(" Loading Image Model...")
82
  IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3").to(device)
83
  IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
84
+
85
+ print(" Loading Tracker Model...")
86
  TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3").to(device)
87
  TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
88
+
89
+ print(" Loading Video Model...")
90
  VID_MODEL = Sam3VideoModel.from_pretrained("DiffusionWave/sam3").to(device, dtype=torch.bfloat16)
91
  VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("DiffusionWave/sam3")
92
 
93
+ print("✅ All models loaded successfully!")
94
  except Exception as e:
95
+ print(f"❌ Error loading models: {e}")
96
  IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = VID_MODEL = VID_PROCESSOR = None
97
 
98
  # ============ HISTORY MANAGEMENT ============
99
  def load_history():
 
100
  if os.path.exists(HISTORY_FILE):
101
  try:
102
  with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
 
105
  return []
106
  return []
107
 
108
+ def save_history(item):
 
109
  history = load_history()
110
+ history.insert(0, item)
111
+ history = history[:200]
112
  with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
113
  json.dump(history, f, indent=2, ensure_ascii=False)
114
 
115
  def get_history_stats():
 
116
  history = load_history()
117
  total = len(history)
118
  completed = sum(1 for h in history if h['status'] == 'completed')
119
  errors = sum(1 for h in history if h['status'] == 'error')
 
120
  types = {}
121
  for h in history:
122
  t = h['type']
123
  types[t] = types.get(t, 0) + 1
 
124
  return {
125
  'total': total,
126
  'completed': completed,
 
129
  'types': types
130
  }
131
 
132
+ def create_download_package(item_id):
133
+ history = load_history()
134
+ item = next((h for h in history if h['id'] == item_id), None)
135
+
136
+ if not item or item['status'] != 'completed':
137
+ return None
138
+
139
+ zip_path = os.path.join(DOWNLOADS_DIR, f"{item_id}_results.zip")
140
+
141
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
142
+ metadata = {
143
+ 'job_id': item_id,
144
+ 'type': item['type'],
145
+ 'prompt': item.get('prompt', 'N/A'),
146
+ 'timestamp': item['timestamp'],
147
+ 'duration': item.get('duration', 'N/A'),
148
+ 'num_objects': item.get('num_objects', 0)
149
+ }
150
+ zipf.writestr('metadata.json', json.dumps(metadata, indent=2, ensure_ascii=False))
151
+
152
+ if item['type'] == 'image':
153
+ if item.get('output_path') and os.path.exists(item['output_path']):
154
+ zipf.write(item['output_path'], 'overlay.jpg')
155
+ if item.get('segmented_files'):
156
+ for i, f in enumerate(item['segmented_files'], 1):
157
+ if os.path.exists(f):
158
+ zipf.write(f, f'objects/object_{i}.png')
159
+
160
+ elif item['type'] == 'video':
161
+ if item.get('output_path') and os.path.exists(item['output_path']):
162
+ zipf.write(item['output_path'], 'overlay_video.mp4')
163
+ if item.get('mask_video_path') and os.path.exists(item['mask_video_path']):
164
+ zipf.write(item['mask_video_path'], 'masks_only.mp4')
165
+ if item.get('segmented_video_path') and os.path.exists(item['segmented_video_path']):
166
+ zipf.write(item['segmented_video_path'], 'segmented_video.mp4')
167
+
168
+ elif item['type'] == 'click':
169
+ if item.get('output_path') and os.path.exists(item['output_path']):
170
+ zipf.write(item['output_path'], 'result.jpg')
171
+
172
+ return zip_path
173
+
174
+ def get_downloadable_jobs():
175
+ history = load_history()
176
+ choices = []
177
+ for item in history:
178
+ if item['status'] == 'completed':
179
+ type_emoji = {'image': '📷', 'video': '🎥', 'click': '👆'}.get(item['type'], '📄')
180
+ label = f"{type_emoji} [{item['type'].upper()}] {item['prompt'][:35]}... | {item['timestamp']}"
181
+ choices.append((label, item['id']))
182
+ return choices if choices else [("No completed jobs available", None)]
183
+
184
  def format_history_table():
 
185
  history = load_history()
186
  if not history:
187
+ return "<p style='text-align:center; color:#666; padding:40px;'>📭 Chưa có lịch sử xử lý nào</p>"
188
 
189
  html = """
190
  <style>
191
+ .history-table { width: 100%; border-collapse: collapse; font-size: 14px; background: white; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
192
+ .history-table th { background: linear-gradient(90deg, #4682B4, #529AC3); color: white; padding: 14px 12px; text-align: left; font-weight: 600; text-transform: uppercase; font-size: 12px; letter-spacing: 0.5px; }
193
+ .history-table td { padding: 12px; border-bottom: 1px solid #e8e8e8; vertical-align: middle; }
194
+ .history-table tr:hover { background-color: #f8f9fa; }
195
+ .history-table tr:last-child td { border-bottom: none; }
196
+ .status-badge { padding: 5px 12px; border-radius: 14px; font-size: 11px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.5px; display: inline-block; }
197
+ .status-completed { background: linear-gradient(135deg, #d4edda, #c3e6cb); color: #155724; }
198
+ .status-error { background: linear-gradient(135deg, #f8d7da, #f5c6cb); color: #721c24; }
199
+ .type-badge { padding: 5px 10px; border-radius: 10px; font-size: 11px; font-weight: 600; background: linear-gradient(135deg, #e3f2fd, #bbdefb); color: #1565c0; display: inline-block; }
200
+ .prompt-text { max-width: 280px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; color: #333; font-weight: 500; }
201
+ .file-count { font-size: 11px; color: #666; margin-top: 4px; line-height: 1.4; }
202
+ .job-id { font-family: 'Courier New', monospace; font-size: 10px; color: #999; background: #f5f5f5; padding: 3px 6px; border-radius: 4px; }
203
+ .time-info { font-size: 12px; color: #666; }
204
+ .duration { font-size: 11px; color: #999; margin-top: 3px; }
 
 
205
  </style>
206
  <table class='history-table'>
207
  <thead>
208
  <tr>
209
+ <th style='width: 40px; text-align: center;'>#</th>
210
+ <th style='width: 110px;'>Job ID</th>
211
+ <th style='width: 90px;'>Type</th>
212
+ <th style='width: 110px;'>Status</th>
213
  <th>Prompt</th>
214
+ <th style='width: 120px;'>Output Files</th>
215
+ <th style='width: 140px;'>Time</th>
 
216
  </tr>
217
  </thead>
218
  <tbody>
 
220
 
221
  for i, item in enumerate(history[:100], 1):
222
  status_class = f"status-{item['status']}"
223
+ status_text = "✅ Completed" if item['status'] == 'completed' else "❌ Error"
224
 
225
  type_icons = {'image': '📷', 'video': '🎥', 'click': '👆'}
226
  type_icon = type_icons.get(item['type'], '📄')
227
 
228
+ prompt = item.get('prompt', 'N/A')
229
+ prompt_short = prompt[:45] + ('...' if len(prompt) > 45 else '')
230
 
 
231
  file_info = []
232
  if item.get('output_path'):
233
+ file_info.append("Overlay")
234
  if item.get('segmented_files'):
235
+ file_info.append(f"{len(item['segmented_files'])} Objects")
236
  if item.get('mask_video_path'):
237
+ file_info.append("Masks")
238
  if item.get('segmented_video_path'):
239
+ file_info.append("Segmented")
 
 
240
 
241
+ files_text = "<br>".join(file_info) if file_info else "No files"
 
 
 
 
242
 
243
  html += f"""
244
  <tr>
245
+ <td style='text-align: center; font-weight: 600; color: #999;'>{i}</td>
246
+ <td><span class='job-id'>{item['id'][:12]}</span></td>
247
  <td><span class='type-badge'>{type_icon} {item['type'].upper()}</span></td>
248
  <td><span class='status-badge {status_class}'>{status_text}</span></td>
249
+ <td class='prompt-text' title='{prompt}'>{prompt_short}</td>
250
  <td><div class='file-count'>{files_text}</div></td>
251
+ <td>
252
+ <div class='time-info'>{item['timestamp']}</div>
253
+ <div class='duration'>⏱️ {item.get('duration', 'N/A')}</div>
254
+ </td>
255
  </tr>
256
  """
257
 
258
  html += """
259
  </tbody>
260
  </table>
 
 
 
 
 
 
 
 
 
 
261
  """
262
 
263
  return html
264
 
265
  def get_history_gallery():
 
266
  history = load_history()
267
  gallery_items = []
268
 
269
+ for item in history[:30]:
270
+ if item['status'] == 'completed':
271
+ if item.get('output_path') and os.path.exists(item['output_path']):
272
+ caption = f"[{item['type'].upper()}] {item['prompt'][:35]}... | {item['timestamp']}"
273
+ gallery_items.append((item['output_path'], caption))
 
274
 
275
+ return gallery_items if gallery_items else []
276
 
277
  def search_history(keyword, filter_type, filter_status):
 
278
  history = load_history()
279
  filtered = history
280
 
281
  if keyword:
282
  filtered = [h for h in filtered if keyword.lower() in h.get('prompt', '').lower()]
 
283
  if filter_type and filter_type != "all":
284
  filtered = [h for h in filtered if h['type'] == filter_type]
 
285
  if filter_status and filter_status != "all":
286
  filtered = [h for h in filtered if h['status'] == filter_status]
287
 
288
  return filtered
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  def clear_all_history():
 
291
  if os.path.exists(OUTPUTS_DIR):
292
  shutil.rmtree(OUTPUTS_DIR)
293
  os.makedirs(OUTPUTS_DIR)
294
+ if os.path.exists(DOWNLOADS_DIR):
295
+ shutil.rmtree(DOWNLOADS_DIR)
296
+ os.makedirs(DOWNLOADS_DIR)
297
 
298
  with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
299
  json.dump([], f)
300
 
301
+ return "✅ Đã xóa toàn bộ lịch sử và files"
302
 
303
  def export_history_json():
 
304
  history = load_history()
305
  export_path = os.path.join(HISTORY_DIR, f"history_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
306
 
 
309
 
310
  return export_path
311
 
312
+ # ============ PROCESSING UTILS ============
313
  def apply_mask_overlay(base_image, mask_data, opacity=0.5):
 
314
  if isinstance(base_image, np.ndarray):
315
  base_image = Image.fromarray(base_image)
316
  base_image = base_image.convert("RGBA")
 
322
  mask_data = mask_data.cpu().numpy()
323
  mask_data = mask_data.astype(np.uint8)
324
 
325
+ if mask_data.ndim == 4:
326
+ mask_data = mask_data[0]
327
+ if mask_data.ndim == 3 and mask_data.shape[0] == 1:
328
+ mask_data = mask_data[0]
329
 
330
  num_masks = mask_data.shape[0] if mask_data.ndim == 3 else 1
331
  if mask_data.ndim == 2:
 
334
 
335
  try:
336
  color_map = matplotlib.colormaps["rainbow"].resampled(max(num_masks, 1))
337
+ except:
338
  import matplotlib.cm as cm
339
  color_map = cm.get_cmap("rainbow").resampled(max(num_masks, 1))
340
 
341
  rgb_colors = [tuple(int(c * 255) for c in color_map(i)[:3]) for i in range(num_masks)]
342
  composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
343
 
344
+ for i, mask in enumerate(mask_data):
345
+ mask_img = Image.fromarray((mask * 255).astype(np.uint8))
346
+ if mask_img.size != base_image.size:
347
+ mask_img = mask_img.resize(base_image.size, resample=Image.NEAREST)
348
 
349
+ color_fill = Image.new("RGBA", base_image.size, rgb_colors[i] + (0,))
350
+ mask_alpha = mask_img.point(lambda v: int(v * opacity) if v > 0 else 0)
 
351
  color_fill.putalpha(mask_alpha)
352
  composite_layer = Image.alpha_composite(composite_layer, color_fill)
353
 
354
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
355
 
356
  def draw_points_on_image(image, points):
 
357
  if isinstance(image, np.ndarray):
358
  image = Image.fromarray(image)
 
359
  draw_img = image.copy()
360
  draw = ImageDraw.Draw(draw_img)
361
+ for x, y in points:
 
 
362
  r = 8
363
  draw.ellipse((x-r, y-r, x+r, y+r), fill="red", outline="white", width=4)
 
364
  return draw_img
365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  # ============ JOB PROCESSORS ============
367
  @spaces.GPU
368
  def process_image_job(job):
369
+ start = datetime.now()
370
+ img = job['image']
371
+ if isinstance(img, str):
372
+ img = Image.open(img)
 
373
 
374
+ img = img.convert("RGB")
375
+ inputs = IMG_PROCESSOR(images=img, text=job['prompt'], return_tensors="pt").to(device)
376
 
 
 
 
377
  with torch.no_grad():
378
+ outputs = IMG_MODEL(**inputs)
379
+
380
+ results = IMG_PROCESSOR.post_process_instance_segmentation(
381
+ outputs,
382
+ threshold=job.get('conf_thresh', 0.5),
383
  mask_threshold=0.5,
384
+ target_sizes=inputs.get("original_sizes").tolist()
385
  )[0]
 
 
 
 
386
 
387
+ masks = results['masks'].cpu().numpy()
388
+ scores = results['scores'].cpu().numpy()
389
+ annotations = [(m, f"{job['prompt']} ({s:.2f})") for m, s in zip(masks, scores)]
390
 
391
+ out_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.jpg")
392
+ apply_mask_overlay(img, masks).save(out_path)
 
 
393
 
394
+ seg_files = []
395
+ for i, mask in enumerate(masks):
396
+ mask_bool = mask.astype(bool)
397
+ seg = Image.new("RGBA", img.size, (0, 0, 0, 0))
398
+ arr = np.array(img.convert("RGBA"))
399
+ arr[~mask_bool] = [0, 0, 0, 0]
400
+ seg = Image.fromarray(arr)
401
+ bbox = Image.fromarray(mask * 255).getbbox()
 
 
 
 
 
 
 
 
402
  if bbox:
403
+ seg_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_obj_{i+1}.png")
404
+ seg.crop(bbox).save(seg_path)
405
+ seg_files.append(seg_path)
 
 
 
406
 
407
  return {
408
+ 'image': (img, annotations),
409
+ 'output_path': out_path,
410
+ 'segmented_files': seg_files,
411
+ 'num_objects': len(seg_files),
412
+ 'duration': f"{(datetime.now() - start).total_seconds():.2f}s"
413
  }
414
 
415
  @spaces.GPU
416
  def process_video_job(job):
417
+ start = datetime.now()
418
+ cap = cv2.VideoCapture(job['video'])
419
+ fps = cap.get(cv2.CAP_PROP_FPS)
420
+ w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
421
 
422
+ frames = []
423
+ limit = job.get('frame_limit', 0)
424
+ count = 0
425
+ while cap.isOpened():
426
+ ret, frame = cap.read()
427
+ if not ret or (limit > 0 and count >= limit):
428
+ break
429
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
430
+ count += 1
431
+ cap.release()
432
 
433
+ session = VID_PROCESSOR.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
434
+ session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=job['prompt'])
 
 
 
 
 
 
435
 
436
+ out_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.mp4")
437
+ mask_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_masks.mp4")
438
+ seg_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_segmented.mp4")
439
 
440
+ writers = [
441
+ cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)),
442
+ cv2.VideoWriter(mask_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)),
443
+ cv2.VideoWriter(seg_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
444
+ ]
445
 
446
+ total = len(frames)
447
+ for idx, out in enumerate(VID_MODEL.propagate_in_video_iterator(inference_session=session, max_frame_num_to_track=total)):
448
+ proc = VID_PROCESSOR.postprocess_outputs(session, out)
449
+ orig = Image.fromarray(frames[out.frame_idx])
 
 
 
 
 
 
 
 
 
450
 
451
+ if 'masks' in proc:
452
+ masks = proc['masks']
453
+ if masks.ndim == 4:
454
+ masks = masks.squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
 
456
+ overlay = apply_mask_overlay(orig, masks)
457
+ writers[0].write(cv2.cvtColor(np.array(overlay), cv2.COLOR_RGB2BGR))
458
 
459
+ mask_np = masks.cpu().numpy() if isinstance(masks, torch.Tensor) else masks
460
+ combined = np.zeros((h, w), dtype=np.uint8)
461
+ for m in mask_np:
462
+ if m.shape != (h, w):
463
+ m = cv2.resize(m.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST)
464
+ combined = np.maximum(combined, m)
465
 
466
+ mask_frame = np.zeros((h, w, 3), dtype=np.uint8)
467
+ mask_frame[combined > 0] = [255, 255, 255]
468
+ writers[1].write(mask_frame)
469
 
470
+ seg_arr = np.array(orig.convert("RGBA"))
471
+ seg_arr[:, :, 3] = (combined * 255).astype(np.uint8)
472
+ bgr = np.zeros((h, w, 3), dtype=np.uint8)
473
+ bgr[:, :] = [0, 255, 0]
474
  for c in range(3):
475
+ bgr[:, :, c] = np.where(combined > 0, seg_arr[:, :, 2-c], bgr[:, :, c])
476
+ writers[2].write(bgr)
477
+ else:
478
+ orig_bgr = cv2.cvtColor(np.array(orig), cv2.COLOR_RGB2BGR)
479
+ writers[0].write(orig_bgr)
480
+ writers[1].write(np.zeros((h, w, 3), dtype=np.uint8))
481
+ writers[2].write(orig_bgr)
 
 
 
 
 
 
 
 
482
 
483
+ processing_results[job['id']]['progress'] = int((idx + 1) / total * 100)
 
 
484
 
485
+ for w in writers:
486
+ w.release()
487
 
488
  return {
489
+ 'output_path': out_path,
490
+ 'mask_video_path': mask_path,
491
+ 'segmented_video_path': seg_path,
492
+ 'duration': f"{(datetime.now() - start).total_seconds():.2f}s"
493
  }
494
 
495
  @spaces.GPU
496
  def process_click_job(job):
497
+ start = datetime.now()
498
+ img = job['image']
499
+ if isinstance(img, str):
500
+ img = Image.open(img)
 
 
 
 
501
 
502
+ inputs = TRK_PROCESSOR(
503
+ images=img,
504
+ input_points=[[job['points']]],
505
+ input_labels=[[job['labels']]],
506
+ return_tensors="pt"
507
+ ).to(device)
508
 
509
  with torch.no_grad():
510
  outputs = TRK_MODEL(**inputs, multimask_output=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511
 
512
+ masks = TRK_PROCESSOR.post_process_masks(
513
+ outputs.pred_masks.cpu(),
514
+ inputs["original_sizes"],
515
+ binarize=True
516
+ )[0]
 
517
 
518
+ result = apply_mask_overlay(img, masks[0])
519
+ result = draw_points_on_image(result, job['points'])
520
 
521
+ out_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_result.jpg")
522
+ result.save(out_path)
 
 
 
 
 
 
 
 
523
 
524
+ return {
525
+ 'image': result,
526
+ 'output_path': out_path,
527
+ 'duration': f"{(datetime.now() - start).total_seconds():.2f}s"
 
 
 
 
528
  }
 
 
 
529
 
530
+ # ============ BACKGROUND WORKER ============
531
+ def background_worker():
532
+ while True:
533
+ job = processing_queue.get()
534
+ if job is None:
535
+ break
536
+
537
+ processing_results[job['id']] = {'status': 'processing', 'progress': 0}
538
+
539
+ try:
540
+ if job['type'] == 'image':
541
+ result = process_image_job(job)
542
+ elif job['type'] == 'video':
543
+ result = process_video_job(job)
544
+ elif job['type'] == 'click':
545
+ result = process_click_job(job)
546
+
547
+ processing_results[job['id']] = {
548
+ 'status': 'completed',
549
+ 'result': result,
550
+ 'progress': 100
551
+ }
552
+
553
+ save_history({
554
+ 'id': job['id'],
555
+ 'type': job['type'],
556
+ 'prompt': job.get('prompt', 'N/A'),
557
+ 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
558
+ 'status': 'completed',
559
+ **result
560
+ })
561
+
562
+ except Exception as e:
563
+ processing_results[job['id']] = {
564
+ 'status': 'error',
565
+ 'error': str(e),
566
+ 'progress': 0
567
+ }
568
+ save_history({
569
+ 'id': job['id'],
570
+ 'type': job['type'],
571
+ 'prompt': job.get('prompt', 'N/A'),
572
+ 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
573
+ 'status': 'error',
574
+ 'error': str(e)
575
+ })
576
 
577
+ threading.Thread(target=background_worker, daemon=True).start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578
 
579
+ # ============ GRADIO UI ============
580
+ custom_css = """
581
+ #col-container { margin: 0 auto; max-width: 1400px; }
582
+ #main-title h1 { font-size: 2.2em !important; font-weight: 700; background: linear-gradient(135deg, #4682B4, #764ba2); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
583
+ .stat-card { padding: 24px; border-radius: 16px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; text-align: center; box-shadow: 0 4px 15px rgba(0,0,0,0.2); }
584
+ .stat-number { font-size: 2.8em; font-weight: 800; margin: 12px 0; text-shadow: 2px 2px 4px rgba(0,0,0,0.2); }
585
+ .stat-label { font-size: 1.1em; opacity: 0.95; font-weight: 500; }
586
  """
587
 
588
+ with gr.Blocks(css=custom_css, theme=app_theme, title="SAM3 Segmentation") as demo:
589
  with gr.Column(elem_id="col-container"):
590
  gr.Markdown("# **SAM3: Segment Anything Model 3** 🚀", elem_id="main-title")
591
+ gr.Markdown("### Xử lý ảnh/video với background processing - Không giới hạn thời gian - Download đầy đủ kết quả")
592
+
593
  with gr.Tabs():
594
+ # ===== IMAGE TAB =====
595
  with gr.Tab("📷 Image Segmentation"):
596
  with gr.Row():
597
  with gr.Column(scale=1):
598
+ img_input = gr.Image(label="📤 Upload Image", type="pil", height=350)
599
+ img_prompt = gr.Textbox(
600
+ label="✍️ Text Prompt",
601
+ placeholder="e.g., cat, person, car, building...",
602
+ lines=2
603
+ )
604
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
605
+ img_conf = gr.Slider(0.0, 1.0, 0.45, 0.05, label="Confidence Threshold")
606
 
607
+ img_submit = gr.Button("🚀 Submit Job (Background)", variant="primary", size="lg")
608
+ img_check = gr.Button("🔍 Check Status", variant="secondary")
609
+ img_job_id = gr.Textbox(label="Job ID", visible=False)
610
+
611
  with gr.Column(scale=1.5):
612
+ img_result = gr.AnnotatedImage(label="🎨 Segmented Result (Overlay)", height=410)
613
+ img_status = gr.Textbox(label="📊 Status", interactive=False)
614
 
615
  with gr.Accordion("📦 Extracted Objects", open=True):
616
+ gr.Markdown("**Các đối tượng được tách ra (PNG với nền trong suốt):**")
617
+ img_gallery = gr.Gallery(
618
+ label="Segmented Objects",
619
  columns=3,
620
  height=300,
621
  object_fit="contain"
622
  )
623
 
624
+ def submit_img(img, prompt, conf):
625
+ if not img or not prompt:
626
+ return None, "❌ Vui lòng cung cấp ảnh và prompt", "", []
627
+ jid = str(uuid.uuid4())
628
+ processing_queue.put({
629
+ 'id': jid,
630
+ 'type': 'image',
631
+ 'image': img,
632
+ 'prompt': prompt,
633
+ 'conf_thresh': conf
634
+ })
635
+ return None, f"✅ Đã thêm vào hàng chờ (ID: {jid[:8]}). Đang xử lý...", jid, []
636
+
637
+ def check_img(jid):
638
+ if not jid or jid not in processing_results:
639
+ return None, "❌ Không tìm thấy công việc", []
640
 
641
+ r = processing_results[jid]
642
 
643
+ if r['status'] == 'processing':
644
+ return None, f"⏳ Đang xử lý... {r['progress']}%", []
645
+ elif r['status'] == 'completed':
646
+ res = r['result']
647
+ gal = [f for f in res.get('segmented_files', []) if os.path.exists(f)]
648
+ status = f"✅ Hoàn thành! Đã tách được {len(gal)} đối tượng | Thời gian: {res.get('duration', 'N/A')}"
649
+ return res['image'], status, gal
 
 
 
 
 
 
 
650
  else:
651
+ return None, f"❌ Lỗi: {r.get('error', 'Unknown')}", []
652
 
653
+ img_submit.click(
654
+ fn=submit_img,
655
+ inputs=[img_input, img_prompt, img_conf],
656
+ outputs=[img_result, img_status, img_job_id, img_gallery]
657
  )
658
 
659
+ img_check.click(
660
+ fn=check_img,
661
+ inputs=[img_job_id],
662
+ outputs=[img_result, img_status, img_gallery]
663
  )
664
 
665
+ # ===== VIDEO TAB =====
666
  with gr.Tab("🎥 Video Segmentation"):
667
  with gr.Row():
668
  with gr.Column():
669
+ vid_input = gr.Video(label="📤 Upload Video", format="mp4", height=320)
670
+ vid_prompt = gr.Textbox(
671
+ label="✍️ Text Prompt",
672
+ placeholder="e.g., person running, red car, dog...",
673
+ lines=2
674
+ )
675
 
676
+ with gr.Accordion("⚙️ Settings", open=True):
677
+ vid_frames = gr.Slider(
678
+ 10, 500, 60, 10,
679
+ label="Max Frames (0 = All frames)",
680
+ info="Giới hạn số frame để xử lý nhanh hơn"
681
+ )
682
 
683
+ vid_submit = gr.Button("🚀 Submit Job (Background)", variant="primary", size="lg")
684
+ vid_check = gr.Button("🔍 Check Status", variant="secondary")
685
+ vid_job_id = gr.Textbox(label="Job ID", visible=False)
686
+
687
  with gr.Column():
688
+ gr.Markdown("### 📹 Video Outputs (3 versions)")
689
 
690
  with gr.Tabs():
691
+ with gr.Tab("1️⃣ Overlay"):
692
+ vid_overlay = gr.Video(label="Original + Color Masks")
693
+ gr.Markdown("*Video gốc với màu mask phủ lên*")
694
 
695
+ with gr.Tab("2️⃣ Masks Only"):
696
+ vid_masks = gr.Video(label="White Masks on Black")
697
+ gr.Markdown("*Chỉ hiển thị mask màu trắng trên nền đen*")
698
 
699
+ with gr.Tab("3️⃣ Segmented"):
700
+ vid_segmented = gr.Video(label="Green Screen Background")
701
+ gr.Markdown("*Đối tượng với nền xanh lá (green screen)*")
702
 
703
+ vid_status = gr.Textbox(label="📊 Status", interactive=False)
704
+
705
+ def submit_vid(vid, prompt, frames):
706
+ if not vid or not prompt:
707
+ return None, None, None, "❌ Vui lòng cung cấp video và prompt", ""
708
+ jid = str(uuid.uuid4())
709
+ processing_queue.put({
710
+ 'id': jid,
711
+ 'type': 'video',
712
+ 'video': vid,
713
+ 'prompt': prompt,
714
+ 'frame_limit': frames
715
+ })
716
+ return None, None, None, f"✅ Đã thêm vào hàng chờ (ID: {jid[:8]}). Đang xử lý...", jid
717
 
718
+ def check_vid(jid):
719
+ if not jid or jid not in processing_results:
720
+ return None, None, None, "❌ Không tìm thấy công việc"
 
721
 
722
+ r = processing_results[jid]
723
 
724
+ if r['status'] == 'processing':
725
+ return None, None, None, f"⏳ Đang xử lý... {r['progress']}%"
726
+ elif r['status'] == 'completed':
727
+ res = r['result']
728
+ status = f"""✅ Hoàn thành! Thời gian: {res.get('duration', 'N/A')}
 
 
 
 
 
 
 
 
729
 
730
+ 📹 3 video đã được tạo:
731
+ • Overlay - Ảnh gốc với mask màu
732
+ • Masks Only - Chỉ mask (trắng/đen)
733
+ • Segmented - Đối tượng với green screen"""
734
+ return (
735
+ res.get('output_path'),
736
+ res.get('mask_video_path'),
737
+ res.get('segmented_video_path'),
738
+ status
739
+ )
740
  else:
741
+ return None, None, None, f"❌ Lỗi: {r.get('error', 'Unknown')}"
 
 
 
 
 
 
 
742
 
743
+ vid_submit.click(
744
+ fn=submit_vid,
745
+ inputs=[vid_input, vid_prompt, vid_frames],
746
+ outputs=[vid_overlay, vid_masks, vid_segmented, vid_status, vid_job_id]
747
  )
748
 
749
+ vid_check.click(
750
+ fn=check_vid,
751
+ inputs=[vid_job_id],
752
+ outputs=[vid_overlay, vid_masks, vid_segmented, vid_status]
753
+ )
754
+
755
+ # ===== CLICK TAB =====
756
  with gr.Tab("👆 Click Segmentation"):
757
  with gr.Row():
758
  with gr.Column(scale=1):
759
+ click_input = gr.Image(
760
+ type="pil",
761
+ label="📤 Upload Image & Click Objects",
762
+ interactive=True,
763
+ height=450
764
+ )
765
+ gr.Markdown("""
766
+ **📝 Hướng dẫn:**
767
+ 1. Upload ảnh
768
+ 2. Click vào đối tượng bạn muốn phân đoạn
769
+ 3. Kết quả hiển thị ngay lập tức
770
+ 4. Click "Clear" để reset và bắt đầu lại
771
+ """)
772
+
773
+ click_clear = gr.Button("🔄 Clear Points & Reset", variant="primary")
774
+
775
+ click_pts = gr.State([])
776
+ click_lbl = gr.State([])
777
+
778
+ with gr.Column(scale=1):
779
+ click_output = gr.Image(
780
+ type="pil",
781
+ label="🎨 Result Preview",
782
+ height=450,
783
+ interactive=False
784
+ )
785
+
786
+ gr.Markdown("""
787
+ **💡 Tips:**
788
+ - Click vào trung tâm của đối tượng để có kết quả tốt nhất
789
+ - Các điểm click được hiển thị bằng dấu chấm đỏ
790
+ - Kết quả tự động cập nhật sau mỗi lần click
791
+ """)
792
+
793
+ def on_click(img, evt: gr.SelectData, pts, lbl):
794
+ if pts is None:
795
+ pts = []
796
+ if lbl is None:
797
+ lbl = []
798
+
799
+ pts.append([evt.index[0], evt.index[1]])
800
+ lbl.append(1)
801
+
802
+ jid = str(uuid.uuid4())
803
+ try:
804
+ res = process_click_job({
805
+ 'id': jid,
806
+ 'type': 'click',
807
+ 'image': img,
808
+ 'points': pts,
809
+ 'labels': lbl
810
+ })
811
+ return res['image'], pts, lbl
812
+ except Exception as e:
813
+ print(f"Click error: {e}")
814
+ return img, pts, lbl
815
+
816
+ click_input.select(
817
+ fn=on_click,
818
+ inputs=[click_input, click_pts, click_lbl],
819
+ outputs=[click_output, click_pts, click_lbl]
820
+ )
821
+
822
+ click_clear.click(
823
+ fn=lambda: (None, [], []),
824
+ outputs=[click_output, click_pts, click_lbl]
825
+ )
826
+
827
+ # ===== DOWNLOAD TAB =====
828
+ with gr.Tab("📥 Download Results"):
829
+ gr.Markdown("""
830
+ # 📦 Download Center
831
+ ### Tải về kết quả đã xử lý dưới dạng ZIP
832
+ """)
833
+
834
+ with gr.Row():
835
+ with gr.Column(scale=1):
836
+ gr.Markdown("### 🎯 Select Job to Download")
837
+
838
+ download_dropdown = gr.Dropdown(
839
+ label="Chọn công việc đã hoàn thành",
840
+ choices=get_downloadable_jobs(),
841
+ interactive=True,
842
+ scale=1
843
+ )
844
 
845
  with gr.Row():
846
+ download_refresh = gr.Button("🔄 Refresh List", variant="secondary", scale=1)
847
+ download_btn = gr.Button("📥 Download ZIP", variant="primary", size="lg", scale=2)
848
 
849
+ download_status = gr.Textbox(label="Status", interactive=False)
850
+
 
851
  with gr.Column(scale=1):
852
+ gr.Markdown("### 📄 Download File")
853
+ download_file = gr.File(label="Your ZIP file will appear here")
854
+
855
+ gr.Markdown("""
856
+ **📦 Package Contents:**
857
+
858
+ **Image Jobs:**
859
+ - `overlay.jpg` - Ảnh với mask màu
860
+ - `objects/object_*.png` - Từng đối tượng riêng lẻ (PNG transparent)
861
+ - `metadata.json` - Thông tin chi tiết
862
+
863
+ **Video Jobs:**
864
+ - `overlay_video.mp4` - Video với mask màu
865
+ - `masks_only.mp4` - Chỉ mask trắng/đen
866
+ - `segmented_video.mp4` - Video với green screen
867
+ - `metadata.json` - Thông tin chi tiết
868
+
869
+ **Click Jobs:**
870
+ - `result.jpg` - Ảnh kết quả
871
+ - `metadata.json` - Thông tin chi tiết
872
+ """)
873
+
874
def do_download(job_id):
    """Build the ZIP package for the selected job.

    Returns a (file_path, status_message) pair for the File and Textbox
    outputs; file_path is None when nothing could be packaged.
    """
    if not job_id:
        return None, "❌ Vui lòng chọn một job"

    archive = create_download_package(job_id)
    if not (archive and os.path.exists(archive)):
        return None, "❌ Không thể tạo package. Job có thể đã bị xóa."

    size_mb = os.path.getsize(archive) / (1024 * 1024)
    return archive, f"✅ Sẵn sàng tải về! Kích thước: {size_mb:.2f} MB"
884
 
885
+ download_refresh.click(
886
+ fn=lambda: gr.Dropdown(choices=get_downloadable_jobs()),
887
+ outputs=[download_dropdown]
 
888
  )
889
 
890
+ download_btn.click(
891
+ fn=do_download,
892
+ inputs=[download_dropdown],
893
+ outputs=[download_file, download_status]
894
  )
895
 
896
+ # ===== HISTORY & STATS TAB =====
897
+ with gr.Tab("📊 History & Statistics"):
898
  with gr.Row():
 
899
  with gr.Column(scale=1):
900
+ gr.Markdown("### 📈 Statistics Dashboard")
901
 
902
def update_stats():
    """Format the aggregate job statistics as four Markdown stat-card strings."""
    s = get_history_stats()
    total_md = f"**{s['total']}**\n\nTổng số jobs"
    done_md = f"**{s['completed']}**\n\nHoàn thành"
    error_md = f"**{s['errors']}**\n\nLỗi"
    rate_md = f"**{s['success_rate']}**\n\nTỷ lệ thành công"
    return total_md, done_md, error_md, rate_md
910
 
911
  with gr.Row():
912
+ stat_total = gr.Markdown("**0**\n\nTổng số jobs", elem_classes=["stat-card"])
913
+ stat_completed = gr.Markdown("**0**\n\nHoàn thành", elem_classes=["stat-card"])
 
 
 
914
 
 
915
  with gr.Row():
916
+ stat_errors = gr.Markdown("**0**\n\nLỗi", elem_classes=["stat-card"])
917
+ stat_success = gr.Markdown("**0%**\n\nTỷ lệ thành công", elem_classes=["stat-card"])
918
+
919
+ gr.Markdown("### 🎯 Quick Actions")
920
+
921
  with gr.Row():
922
+ btn_refresh = gr.Button("🔄 Refresh All", variant="primary")
923
+ btn_export = gr.Button("📥 Export JSON", variant="secondary")
924
+
925
+ btn_clear_all = gr.Button("🗑️ Clear All History", variant="stop")
926
 
927
  export_file = gr.File(label="Exported File", visible=False)
928
  clear_status = gr.Textbox(label="Status", interactive=False)
929
 
 
930
  with gr.Row():
931
  with gr.Column():
932
+ gr.Markdown("### 📜 Processing History")
933
 
 
934
  with gr.Row():
935
  search_input = gr.Textbox(
936
  placeholder="🔍 Tìm kiếm theo prompt...",
 
952
 
953
  history_table = gr.HTML(value=format_history_table())
954
 
 
955
  with gr.Row():
956
  with gr.Column():
957
+ gr.Markdown("### 🖼️ Gallery - Recent Outputs")
958
  history_gallery = gr.Gallery(
959
  value=get_history_gallery(),
960
+ label="Kết quả gần đây",
961
  columns=4,
962
  height=400,
963
  object_fit="contain"
964
  )
965
 
 
966
  def refresh_all():
967
  return (
968
  *update_stats(),
 
988
  outputs=[stat_total, stat_completed, stat_errors, stat_success, history_table, history_gallery]
989
  )
990
 
 
991
def filter_and_display(keyword, ftype, fstatus):
    """Return HTML for the history table after applying search/filter criteria."""
    filtered = search_history(keyword, ftype, fstatus)
    if not filtered:
        return "<p style='text-align:center; color:#666; padding:40px;'>🔍 Không tìm thấy kết quả phù hợp</p>"
    # NOTE(review): `filtered` is only used for the emptiness check above —
    # format_history_table() is called with no arguments, so the rendered
    # table presumably shows the FULL history, not the filtered subset.
    # Verify whether format_history_table can accept the filtered rows.
    return format_history_table()
 
 
 
996
 
997
  search_input.change(
998
  fn=filter_and_display,
999
  inputs=[search_input, filter_type, filter_status],
1000
  outputs=[history_table]
1001
  )
1002
+
1003
  filter_type.change(
1004
  fn=filter_and_display,
1005
  inputs=[search_input, filter_type, filter_status],
1006
  outputs=[history_table]
1007
  )
1008
+
1009
  filter_status.change(
1010
  fn=filter_and_display,
1011
  inputs=[search_input, filter_type, filter_status],
1012
  outputs=[history_table]
1013
  )
1014
+
1015
+ # Footer
1016
+ gr.Markdown("""
1017
+ ---
1018
+ **SAM3: Segment Anything Model 3** | Powered by DiffusionWave | Background Processing Enabled | No Timeout Limits
1019
+ """)
1020
 
1021
if __name__ == "__main__":
    # Script entry point: launch the Gradio UI with the custom CSS and theme.
    # ssr_mode=False disables server-side rendering; show_error=True surfaces
    # exceptions in the browser instead of failing silently.
    demo.launch(
        css=custom_css,
        theme=app_theme,
        ssr_mode=False,
        show_error=True
    )