Translsis commited on
Commit
25cc101
·
verified ·
1 Parent(s): 9a10593

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +503 -72
app.py CHANGED
@@ -21,6 +21,7 @@ from datetime import datetime
21
  import threading
22
  import queue
23
  import uuid
 
24
 
25
  # ============ THEME SETUP ============
26
  colors.steel_blue = colors.Color(
@@ -91,7 +92,8 @@ print(f"🖥️ Using compute device: {device}")
91
 
92
  # History storage
93
  HISTORY_DIR = "processing_history"
94
- os.makedirs(HISTORY_DIR, exist_ok=True)
 
95
  HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
96
 
97
  # Background processing queue
@@ -123,7 +125,7 @@ def load_history():
123
  """Load processing history from JSON file"""
124
  if os.path.exists(HISTORY_FILE):
125
  try:
126
- with open(HISTORY_FILE, 'r') as f:
127
  return json.load(f)
128
  except:
129
  return []
@@ -132,26 +134,202 @@ def load_history():
132
  def save_history(history_item):
133
  """Save a new history item"""
134
  history = load_history()
135
- history.insert(0, history_item) # Add to beginning
136
- history = history[:100] # Keep last 100 items
137
- with open(HISTORY_FILE, 'w') as f:
138
- json.dump(history, f, indent=2)
139
 
140
- def get_history_display():
141
- """Format history for display"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  history = load_history()
143
  if not history:
144
- return "Chưa có lịch sử xử lý nào"
145
 
146
- display_text = ""
147
- for i, item in enumerate(history[:50], 1):
148
- status_emoji = "✅" if item['status'] == 'completed' else "❌"
149
- display_text += f"{status_emoji} **{item['type'].upper()}** - {item['timestamp']}\n"
150
- display_text += f" Prompt: {item['prompt']}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  if item.get('output_path'):
152
- display_text += f" File: `{os.path.basename(item['output_path'])}`\n"
153
- display_text += "\n"
154
- return display_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # ============ UTILITY FUNCTIONS ============
157
  def apply_mask_overlay(base_image, mask_data, opacity=0.5):
@@ -240,14 +418,18 @@ def background_worker():
240
  'progress': 100
241
  }
242
 
243
- # Save to history
244
  save_history({
245
  'id': job_id,
246
  'type': job_type,
247
  'prompt': job.get('prompt', 'N/A'),
248
  'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
249
  'status': 'completed',
250
- 'output_path': result.get('output_path')
 
 
 
 
 
251
  })
252
 
253
  except Exception as e:
@@ -267,7 +449,6 @@ def background_worker():
267
  except Exception as e:
268
  print(f"Worker error: {e}")
269
 
270
- # Start background worker
271
  worker_thread = threading.Thread(target=background_worker, daemon=True)
272
  worker_thread.start()
273
 
@@ -275,6 +456,7 @@ worker_thread.start()
275
  @spaces.GPU
276
  def process_image_job(job):
277
  """Process image segmentation job"""
 
278
  source_img = job['image']
279
  text_query = job['prompt']
280
  conf_thresh = job.get('conf_thresh', 0.5)
@@ -303,19 +485,47 @@ def process_image_job(job):
303
  label_str = f"{text_query} ({raw_scores[idx]:.2f})"
304
  annotation_list.append((mask_array, label_str))
305
 
306
- # Save output
307
- output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.jpg")
308
  result_img = apply_mask_overlay(pil_image, raw_masks)
309
  result_img.save(output_path)
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  return {
312
  'image': (pil_image, annotation_list),
313
- 'output_path': output_path
 
 
 
314
  }
315
 
316
  @spaces.GPU
317
  def process_video_job(job):
318
  """Process video segmentation job"""
 
319
  source_vid = job['video']
320
  text_query = job['prompt']
321
  frame_limit = job.get('frame_limit', 60)
@@ -337,9 +547,18 @@ def process_video_job(job):
337
  session = VID_PROCESSOR.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
338
  session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)
339
 
340
- output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.mp4")
 
341
  video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
342
 
 
 
 
 
 
 
 
 
343
  total_frames = len(video_frames)
344
  for frame_idx, model_out in enumerate(VID_MODEL.propagate_in_video_iterator(inference_session=session, max_frame_num_to_track=total_frames)):
345
  post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
@@ -349,22 +568,71 @@ def process_video_job(job):
349
  if 'masks' in post_processed:
350
  detected_masks = post_processed['masks']
351
  if detected_masks.ndim == 4: detected_masks = detected_masks.squeeze(1)
352
- final_frame = apply_mask_overlay(original_pil, detected_masks)
353
- else:
354
- final_frame = original_pil
355
 
356
- video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
- # Update progress
359
  progress = int((frame_idx + 1) / total_frames * 100)
360
  processing_results[job['id']]['progress'] = progress
361
 
362
  video_writer.release()
363
- return {'output_path': output_path}
 
 
 
 
 
 
 
 
 
 
364
 
365
  @spaces.GPU
366
  def process_click_job(job):
367
  """Process click segmentation job"""
 
368
  input_image = job['image']
369
  points_state = job['points']
370
  labels_state = job['labels']
@@ -384,17 +652,19 @@ def process_click_job(job):
384
  final_img = apply_mask_overlay(input_image, masks[0])
385
  final_img = draw_points_on_image(final_img, points_state)
386
 
387
- output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.jpg")
388
  final_img.save(output_path)
389
 
 
 
390
  return {
391
  'image': final_img,
392
- 'output_path': output_path
 
393
  }
394
 
395
  # ============ UI HANDLERS ============
396
  def submit_image_job(source_img, text_query, conf_thresh):
397
- """Submit image segmentation job to background queue"""
398
  if source_img is None or not text_query:
399
  return None, "❌ Vui lòng cung cấp ảnh và prompt", ""
400
 
@@ -411,7 +681,6 @@ def submit_image_job(source_img, text_query, conf_thresh):
411
  return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
412
 
413
  def check_image_status(job_id):
414
- """Check status of image processing job"""
415
  if not job_id or job_id not in processing_results:
416
  return None, "Không tìm thấy công việc"
417
 
@@ -425,7 +694,6 @@ def check_image_status(job_id):
425
  return None, f"❌ Lỗi: {result.get('error', 'Unknown')}"
426
 
427
  def submit_video_job(source_vid, text_query, frame_limit, time_limit):
428
- """Submit video segmentation job to background queue"""
429
  if not source_vid or not text_query:
430
  return None, "❌ Vui lòng cung cấp video và prompt", ""
431
 
@@ -443,7 +711,6 @@ def submit_video_job(source_vid, text_query, frame_limit, time_limit):
443
  return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
444
 
445
  def check_video_status(job_id):
446
- """Check status of video processing job"""
447
  if not job_id or job_id not in processing_results:
448
  return None, "Không tìm thấy công việc"
449
 
@@ -457,7 +724,6 @@ def check_video_status(job_id):
457
  return None, f"❌ Lỗi: {result.get('error', 'Unknown')}"
458
 
459
  def image_click_handler(image, evt: gr.SelectData, points_state, labels_state):
460
- """Handle click events for interactive segmentation"""
461
  x, y = evt.index
462
 
463
  if points_state is None: points_state = []
@@ -466,7 +732,6 @@ def image_click_handler(image, evt: gr.SelectData, points_state, labels_state):
466
  points_state.append([x, y])
467
  labels_state.append(1)
468
 
469
- # Process immediately (can be changed to background if needed)
470
  job_id = str(uuid.uuid4())
471
  job = {
472
  'id': job_id,
@@ -485,9 +750,11 @@ def image_click_handler(image, evt: gr.SelectData, points_state, labels_state):
485
 
486
  # ============ GRADIO INTERFACE ============
487
  custom_css="""
488
- #col-container { margin: 0 auto; max-width: 1200px; }
489
  #main-title h1 { font-size: 2.1em !important; }
490
- .history-box { max-height: 600px; overflow-y: auto; }
 
 
491
  """
492
 
493
  with gr.Blocks(css=custom_css, theme=app_theme) as demo:
@@ -510,8 +777,41 @@ with gr.Blocks(css=custom_css, theme=app_theme) as demo:
510
  job_id_img = gr.Textbox(label="Job ID", visible=False)
511
 
512
  with gr.Column(scale=1.5):
513
- image_result = gr.AnnotatedImage(label="Segmented Result", height=410)
514
  status_img = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
516
  btn_submit_img.click(
517
  fn=submit_image_job,
@@ -520,9 +820,9 @@ with gr.Blocks(css=custom_css, theme=app_theme) as demo:
520
  )
521
 
522
  btn_check_img.click(
523
- fn=check_image_status,
524
  inputs=[job_id_img],
525
- outputs=[image_result, status_img]
526
  )
527
 
528
  # ===== VIDEO SEGMENTATION TAB =====
@@ -541,19 +841,56 @@ with gr.Blocks(css=custom_css, theme=app_theme) as demo:
541
  job_id_vid = gr.Textbox(label="Job ID", visible=False)
542
 
543
  with gr.Column():
544
- video_result = gr.Video(label="Processed Video")
 
 
 
 
 
 
 
 
 
 
 
545
  status_vid = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
 
547
  btn_submit_vid.click(
548
  fn=submit_video_job,
549
  inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
550
- outputs=[video_result, status_vid, job_id_vid]
551
  )
552
 
553
  btn_check_vid.click(
554
- fn=check_video_status,
555
  inputs=[job_id_vid],
556
- outputs=[video_result, status_vid]
557
  )
558
 
559
  # ===== CLICK SEGMENTATION TAB =====
@@ -583,36 +920,130 @@ with gr.Blocks(css=custom_css, theme=app_theme) as demo:
583
  outputs=[img_click_output, st_click_points, st_click_labels]
584
  )
585
 
586
- # ===== HISTORY TAB =====
587
- with gr.Tab("📜 Lịch Sử Xử "):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
  with gr.Row():
589
  with gr.Column():
590
- btn_refresh_history = gr.Button("🔄 Refresh History", variant="primary")
591
- history_display = gr.Markdown(value=get_history_display(), elem_classes="history-box")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
 
593
- with gr.Accordion("Hướng dẫn", open=False):
594
- gr.Markdown("""
595
- ### Lịch sử lưu:
596
- - ✅ **Hoàn thành**: File đã được xử lý thành công
597
- - ❌ **Lỗi**: Xử lý thất bại
598
- - Tất cả file output được lưu trong thư mục `processing_history/`
599
- - Hệ thống giữ lại 100 lịch sử gần nhất
600
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
601
 
602
- btn_refresh_history.click(
603
- fn=get_history_display,
604
- outputs=[history_display]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
  )
606
-
607
- # ===== BATCH PROCESSING TAB =====
608
- with gr.Tab("⚙️ Batch Processing"):
609
- gr.Markdown("### Xử lý hàng loạt (Coming Soon)")
610
- gr.Markdown("""
611
- Tính năng này sẽ cho phép bạn:
612
- - Upload nhiều ảnh/video cùng lúc
613
- - Tự động xử lý tuần tự
614
- - Download tất cả kết quả dưới dạng ZIP
615
- """)
616
 
617
  if __name__ == "__main__":
618
  demo.launch(
 
21
  import threading
22
  import queue
23
  import uuid
24
+ import shutil
25
 
26
  # ============ THEME SETUP ============
27
  colors.steel_blue = colors.Color(
 
92
 
93
  # History storage
94
  HISTORY_DIR = "processing_history"
95
+ OUTPUTS_DIR = os.path.join(HISTORY_DIR, "outputs")
96
+ os.makedirs(OUTPUTS_DIR, exist_ok=True)
97
  HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
98
 
99
  # Background processing queue
 
125
  """Load processing history from JSON file"""
126
  if os.path.exists(HISTORY_FILE):
127
  try:
128
+ with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
129
  return json.load(f)
130
  except:
131
  return []
 
134
  def save_history(history_item):
135
  """Save a new history item"""
136
  history = load_history()
137
+ history.insert(0, history_item)
138
+ history = history[:200] # Keep last 200 items
139
+ with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
140
+ json.dump(history, f, indent=2, ensure_ascii=False)
141
 
142
+ def get_history_stats():
143
+ """Get statistics from history"""
144
+ history = load_history()
145
+ total = len(history)
146
+ completed = sum(1 for h in history if h['status'] == 'completed')
147
+ errors = sum(1 for h in history if h['status'] == 'error')
148
+
149
+ types = {}
150
+ for h in history:
151
+ t = h['type']
152
+ types[t] = types.get(t, 0) + 1
153
+
154
+ return {
155
+ 'total': total,
156
+ 'completed': completed,
157
+ 'errors': errors,
158
+ 'success_rate': f"{(completed/total*100):.1f}%" if total > 0 else "0%",
159
+ 'types': types
160
+ }
161
+
162
+ def format_history_table():
163
+ """Format history as HTML table"""
164
  history = load_history()
165
  if not history:
166
+ return "<p style='text-align:center; color:#666;'>Chưa có lịch sử xử lý nào</p>"
167
 
168
+ html = """
169
+ <style>
170
+ .history-table { width: 100%; border-collapse: collapse; font-size: 14px; }
171
+ .history-table th { background: linear-gradient(90deg, #4682B4, #529AC3); color: white; padding: 12px; text-align: left; font-weight: 600; }
172
+ .history-table td { padding: 10px; border-bottom: 1px solid #ddd; }
173
+ .history-table tr:hover { background-color: #f5f5f5; }
174
+ .status-badge { padding: 4px 10px; border-radius: 12px; font-size: 12px; font-weight: 600; }
175
+ .status-completed { background: #d4edda; color: #155724; }
176
+ .status-error { background: #f8d7da; color: #721c24; }
177
+ .status-processing { background: #fff3cd; color: #856404; }
178
+ .type-badge { padding: 3px 8px; border-radius: 8px; font-size: 11px; font-weight: 600; background: #e3f2fd; color: #1976d2; }
179
+ .action-btn { padding: 5px 12px; margin: 2px; border: none; border-radius: 6px; cursor: pointer; font-size: 12px; font-weight: 600; }
180
+ .btn-download { background: #28a745; color: white; }
181
+ .btn-delete { background: #dc3545; color: white; }
182
+ .btn-download:hover { background: #218838; }
183
+ .btn-delete:hover { background: #c82333; }
184
+ .prompt-text { max-width: 300px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; }
185
+ .file-count { font-size: 11px; color: #666; margin-top: 3px; }
186
+ </style>
187
+ <table class='history-table'>
188
+ <thead>
189
+ <tr>
190
+ <th style='width: 40px;'>#</th>
191
+ <th style='width: 80px;'>Loại</th>
192
+ <th style='width: 100px;'>Trạng thái</th>
193
+ <th>Prompt</th>
194
+ <th style='width: 100px;'>Files</th>
195
+ <th style='width: 150px;'>Thời gian</th>
196
+ <th style='width: 150px;'>Thao tác</th>
197
+ </tr>
198
+ </thead>
199
+ <tbody>
200
+ """
201
+
202
+ for i, item in enumerate(history[:100], 1):
203
+ status_class = f"status-{item['status']}"
204
+ status_text = "✅ Hoàn thành" if item['status'] == 'completed' else "❌ Lỗi" if item['status'] == 'error' else "⏳ Đang xử lý"
205
+
206
+ type_icons = {'image': '📷', 'video': '🎥', 'click': '👆'}
207
+ type_icon = type_icons.get(item['type'], '📄')
208
+
209
+ prompt = item.get('prompt', 'N/A')[:50] + ('...' if len(item.get('prompt', '')) > 50 else '')
210
+
211
+ # Count files
212
+ file_info = []
213
  if item.get('output_path'):
214
+ file_info.append("Overlay")
215
+ if item.get('segmented_files'):
216
+ file_info.append(f"{len(item['segmented_files'])} Objects")
217
+ if item.get('mask_video_path'):
218
+ file_info.append("Masks")
219
+ if item.get('segmented_video_path'):
220
+ file_info.append("Segmented")
221
+
222
+ files_text = "<br>".join(file_info) if file_info else "N/A"
223
+
224
+ download_btn = ""
225
+ if item.get('output_path') or item.get('segmented_files'):
226
+ download_btn = f"<button class='action-btn btn-download' onclick='downloadFiles(\"{item['id']}\")'>📥 Download</button>"
227
+
228
+ delete_btn = f"<button class='action-btn btn-delete' onclick='deleteHistory(\"{item['id']}\")'>🗑️ Xóa</button>"
229
+
230
+ html += f"""
231
+ <tr>
232
+ <td>{i}</td>
233
+ <td><span class='type-badge'>{type_icon} {item['type'].upper()}</span></td>
234
+ <td><span class='status-badge {status_class}'>{status_text}</span></td>
235
+ <td class='prompt-text' title='{item.get("prompt", "N/A")}'>{prompt}</td>
236
+ <td><div class='file-count'>{files_text}</div></td>
237
+ <td>{item['timestamp']}<br><small>{item.get('duration', '')}</small></td>
238
+ <td>{download_btn}{delete_btn}</td>
239
+ </tr>
240
+ """
241
+
242
+ html += """
243
+ </tbody>
244
+ </table>
245
+ <script>
246
+ function downloadFiles(id) {
247
+ alert('Download functionality: ' + id + '\\nFiles will be packaged as ZIP');
248
+ }
249
+ function deleteHistory(id) {
250
+ if(confirm('Bạn có chắc muốn xóa mục này?')) {
251
+ alert('Deleted: ' + id);
252
+ }
253
+ }
254
+ </script>
255
+ """
256
+
257
+ return html
258
+
259
+ def get_history_gallery():
260
+ """Get recent outputs for gallery display"""
261
+ history = load_history()
262
+ gallery_items = []
263
+
264
+ for item in history[:20]:
265
+ if item['status'] == 'completed' and item.get('output_path'):
266
+ output_path = item['output_path']
267
+ if os.path.exists(output_path):
268
+ caption = f"{item['type'].upper()} | {item['prompt'][:30]}... | {item['timestamp']}"
269
+ gallery_items.append((output_path, caption))
270
+
271
+ return gallery_items
272
+
273
+ def search_history(keyword, filter_type, filter_status):
274
+ """Search and filter history"""
275
+ history = load_history()
276
+ filtered = history
277
+
278
+ if keyword:
279
+ filtered = [h for h in filtered if keyword.lower() in h.get('prompt', '').lower()]
280
+
281
+ if filter_type and filter_type != "all":
282
+ filtered = [h for h in filtered if h['type'] == filter_type]
283
+
284
+ if filter_status and filter_status != "all":
285
+ filtered = [h for h in filtered if h['status'] == filter_status]
286
+
287
+ return filtered
288
+
289
+ def delete_history_item(item_id):
290
+ """Delete a history item and its output file"""
291
+ history = load_history()
292
+ updated_history = []
293
+ deleted = False
294
+
295
+ for item in history:
296
+ if item['id'] == item_id:
297
+ # Delete output file if exists
298
+ if item.get('output_path') and os.path.exists(item['output_path']):
299
+ try:
300
+ os.remove(item['output_path'])
301
+ except:
302
+ pass
303
+ deleted = True
304
+ else:
305
+ updated_history.append(item)
306
+
307
+ if deleted:
308
+ with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
309
+ json.dump(updated_history, f, indent=2, ensure_ascii=False)
310
+ return "✅ Đã xóa thành công"
311
+ return "❌ Không tìm thấy mục cần xóa"
312
+
313
+ def clear_all_history():
314
+ """Clear all history and output files"""
315
+ if os.path.exists(OUTPUTS_DIR):
316
+ shutil.rmtree(OUTPUTS_DIR)
317
+ os.makedirs(OUTPUTS_DIR)
318
+
319
+ with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
320
+ json.dump([], f)
321
+
322
+ return "✅ Đã xóa toàn bộ lịch sử"
323
+
324
+ def export_history_json():
325
+ """Export history as downloadable JSON"""
326
+ history = load_history()
327
+ export_path = os.path.join(HISTORY_DIR, f"history_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
328
+
329
+ with open(export_path, 'w', encoding='utf-8') as f:
330
+ json.dump(history, f, indent=2, ensure_ascii=False)
331
+
332
+ return export_path
333
 
334
  # ============ UTILITY FUNCTIONS ============
335
  def apply_mask_overlay(base_image, mask_data, opacity=0.5):
 
418
  'progress': 100
419
  }
420
 
 
421
  save_history({
422
  'id': job_id,
423
  'type': job_type,
424
  'prompt': job.get('prompt', 'N/A'),
425
  'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
426
  'status': 'completed',
427
+ 'output_path': result.get('output_path'),
428
+ 'segmented_files': result.get('segmented_files', []),
429
+ 'mask_video_path': result.get('mask_video_path'),
430
+ 'segmented_video_path': result.get('segmented_video_path'),
431
+ 'num_objects': result.get('num_objects', 0),
432
+ 'duration': result.get('duration', 'N/A')
433
  })
434
 
435
  except Exception as e:
 
449
  except Exception as e:
450
  print(f"Worker error: {e}")
451
 
 
452
  worker_thread = threading.Thread(target=background_worker, daemon=True)
453
  worker_thread.start()
454
 
 
456
  @spaces.GPU
457
  def process_image_job(job):
458
  """Process image segmentation job"""
459
+ start_time = datetime.now()
460
  source_img = job['image']
461
  text_query = job['prompt']
462
  conf_thresh = job.get('conf_thresh', 0.5)
 
485
  label_str = f"{text_query} ({raw_scores[idx]:.2f})"
486
  annotation_list.append((mask_array, label_str))
487
 
488
+ # Save overlay result
489
+ output_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.jpg")
490
  result_img = apply_mask_overlay(pil_image, raw_masks)
491
  result_img.save(output_path)
492
 
493
+ # Extract and save individual segmented objects
494
+ segmented_files = []
495
+ for idx, mask_array in enumerate(raw_masks):
496
+ # Create transparent background for segmented object
497
+ mask_bool = mask_array.astype(bool)
498
+
499
+ # Create RGBA image
500
+ segmented = Image.new("RGBA", pil_image.size, (0, 0, 0, 0))
501
+ img_array = np.array(pil_image.convert("RGBA"))
502
+
503
+ # Apply mask
504
+ img_array[~mask_bool] = [0, 0, 0, 0]
505
+ segmented = Image.fromarray(img_array)
506
+
507
+ # Crop to bounding box to save space
508
+ bbox = Image.fromarray(mask_array * 255).getbbox()
509
+ if bbox:
510
+ segmented_cropped = segmented.crop(bbox)
511
+ seg_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_object_{idx+1}.png")
512
+ segmented_cropped.save(seg_path)
513
+ segmented_files.append(seg_path)
514
+
515
+ duration = (datetime.now() - start_time).total_seconds()
516
+
517
  return {
518
  'image': (pil_image, annotation_list),
519
+ 'output_path': output_path,
520
+ 'segmented_files': segmented_files,
521
+ 'num_objects': len(segmented_files),
522
+ 'duration': f"{duration:.2f}s"
523
  }
524
 
525
  @spaces.GPU
526
  def process_video_job(job):
527
  """Process video segmentation job"""
528
+ start_time = datetime.now()
529
  source_vid = job['video']
530
  text_query = job['prompt']
531
  frame_limit = job.get('frame_limit', 60)
 
547
  session = VID_PROCESSOR.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
548
  session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)
549
 
550
+ # Overlay video
551
+ output_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.mp4")
552
  video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
553
 
554
+ # Mask-only video (black background with white masks)
555
+ mask_video_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_masks_only.mp4")
556
+ mask_writer = cv2.VideoWriter(mask_video_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
557
+
558
+ # Segmented objects video (transparent background)
559
+ segmented_video_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_segmented.mp4")
560
+ segmented_writer = cv2.VideoWriter(segmented_video_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
561
+
562
  total_frames = len(video_frames)
563
  for frame_idx, model_out in enumerate(VID_MODEL.propagate_in_video_iterator(inference_session=session, max_frame_num_to_track=total_frames)):
564
  post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
 
568
  if 'masks' in post_processed:
569
  detected_masks = post_processed['masks']
570
  if detected_masks.ndim == 4: detected_masks = detected_masks.squeeze(1)
 
 
 
571
 
572
+ # 1. Overlay frame
573
+ overlay_frame = apply_mask_overlay(original_pil, detected_masks)
574
+ video_writer.write(cv2.cvtColor(np.array(overlay_frame), cv2.COLOR_RGB2BGR))
575
+
576
+ # 2. Mask-only frame (white masks on black background)
577
+ mask_frame = np.zeros((vid_h, vid_w, 3), dtype=np.uint8)
578
+ if isinstance(detected_masks, torch.Tensor):
579
+ detected_masks_np = detected_masks.cpu().numpy()
580
+ else:
581
+ detected_masks_np = detected_masks
582
+
583
+ # Combine all masks
584
+ combined_mask = np.zeros((vid_h, vid_w), dtype=np.uint8)
585
+ for mask in detected_masks_np:
586
+ if mask.shape != (vid_h, vid_w):
587
+ mask = cv2.resize(mask.astype(np.uint8), (vid_w, vid_h), interpolation=cv2.INTER_NEAREST)
588
+ combined_mask = np.maximum(combined_mask, mask)
589
+
590
+ mask_frame[combined_mask > 0] = [255, 255, 255]
591
+ mask_writer.write(mask_frame)
592
+
593
+ # 3. Segmented frame (original with background removed)
594
+ segmented_frame = np.array(original_pil.convert("RGBA"))
595
+ alpha_mask = (combined_mask * 255).astype(np.uint8)
596
+ segmented_frame[:, :, 3] = alpha_mask
597
+
598
+ # Convert to BGR for video (with green screen for transparency)
599
+ bgr_frame = np.zeros((vid_h, vid_w, 3), dtype=np.uint8)
600
+ bgr_frame[:, :] = [0, 255, 0] # Green background
601
+
602
+ for c in range(3):
603
+ bgr_frame[:, :, c] = np.where(
604
+ combined_mask > 0,
605
+ segmented_frame[:, :, 2-c], # RGB to BGR
606
+ bgr_frame[:, :, c]
607
+ )
608
+
609
+ segmented_writer.write(bgr_frame)
610
+ else:
611
+ # No masks detected, write original frames
612
+ video_writer.write(cv2.cvtColor(np.array(original_pil), cv2.COLOR_RGB2BGR))
613
+ mask_writer.write(np.zeros((vid_h, vid_w, 3), dtype=np.uint8))
614
+ segmented_writer.write(cv2.cvtColor(np.array(original_pil), cv2.COLOR_RGB2BGR))
615
 
 
616
  progress = int((frame_idx + 1) / total_frames * 100)
617
  processing_results[job['id']]['progress'] = progress
618
 
619
  video_writer.release()
620
+ mask_writer.release()
621
+ segmented_writer.release()
622
+
623
+ duration = (datetime.now() - start_time).total_seconds()
624
+
625
+ return {
626
+ 'output_path': output_path,
627
+ 'mask_video_path': mask_video_path,
628
+ 'segmented_video_path': segmented_video_path,
629
+ 'duration': f"{duration:.2f}s"
630
+ }
631
 
632
  @spaces.GPU
633
  def process_click_job(job):
634
  """Process click segmentation job"""
635
+ start_time = datetime.now()
636
  input_image = job['image']
637
  points_state = job['points']
638
  labels_state = job['labels']
 
652
  final_img = apply_mask_overlay(input_image, masks[0])
653
  final_img = draw_points_on_image(final_img, points_state)
654
 
655
+ output_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_result.jpg")
656
  final_img.save(output_path)
657
 
658
+ duration = (datetime.now() - start_time).total_seconds()
659
+
660
  return {
661
  'image': final_img,
662
+ 'output_path': output_path,
663
+ 'duration': f"{duration:.2f}s"
664
  }
665
 
666
  # ============ UI HANDLERS ============
667
  def submit_image_job(source_img, text_query, conf_thresh):
 
668
  if source_img is None or not text_query:
669
  return None, "❌ Vui lòng cung cấp ảnh và prompt", ""
670
 
 
681
  return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
682
 
683
  def check_image_status(job_id):
 
684
  if not job_id or job_id not in processing_results:
685
  return None, "Không tìm thấy công việc"
686
 
 
694
  return None, f"❌ Lỗi: {result.get('error', 'Unknown')}"
695
 
696
  def submit_video_job(source_vid, text_query, frame_limit, time_limit):
 
697
  if not source_vid or not text_query:
698
  return None, "❌ Vui lòng cung cấp video và prompt", ""
699
 
 
711
  return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
712
 
713
  def check_video_status(job_id):
 
714
  if not job_id or job_id not in processing_results:
715
  return None, "Không tìm thấy công việc"
716
 
 
724
  return None, f"❌ Lỗi: {result.get('error', 'Unknown')}"
725
 
726
  def image_click_handler(image, evt: gr.SelectData, points_state, labels_state):
 
727
  x, y = evt.index
728
 
729
  if points_state is None: points_state = []
 
732
  points_state.append([x, y])
733
  labels_state.append(1)
734
 
 
735
  job_id = str(uuid.uuid4())
736
  job = {
737
  'id': job_id,
 
750
 
751
  # ============ GRADIO INTERFACE ============
752
  custom_css="""
753
+ #col-container { margin: 0 auto; max-width: 1300px; }
754
  #main-title h1 { font-size: 2.1em !important; }
755
+ .stat-card { padding: 20px; border-radius: 12px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; text-align: center; }
756
+ .stat-number { font-size: 2.5em; font-weight: 700; margin: 10px 0; }
757
+ .stat-label { font-size: 1.1em; opacity: 0.9; }
758
  """
759
 
760
  with gr.Blocks(css=custom_css, theme=app_theme) as demo:
 
777
  job_id_img = gr.Textbox(label="Job ID", visible=False)
778
 
779
  with gr.Column(scale=1.5):
780
+ image_result = gr.AnnotatedImage(label="Segmented Result (Overlay)", height=410)
781
  status_img = gr.Textbox(label="Status", interactive=False)
782
+
783
+ with gr.Accordion("📦 Extracted Objects", open=True):
784
+ gr.Markdown("**Các đối tượng được tách ra sẽ hiển thị ở đây:**")
785
+ segmented_gallery = gr.Gallery(
786
+ label="Segmented Objects (PNG with transparent background)",
787
+ columns=3,
788
+ height=300,
789
+ object_fit="contain"
790
+ )
791
+
792
+ def check_and_display_image(job_id):
793
+ """Check status and display both overlay and segmented objects"""
794
+ if not job_id or job_id not in processing_results:
795
+ return None, "Không tìm thấy công việc", []
796
+
797
+ result = processing_results[job_id]
798
+
799
+ if result['status'] == 'processing':
800
+ return None, f"⏳ Đang xử lý... {result['progress']}%", []
801
+ elif result['status'] == 'completed':
802
+ job_result = result['result']
803
+ segmented_files = job_result.get('segmented_files', [])
804
+
805
+ # Create gallery items
806
+ gallery_items = []
807
+ for i, seg_file in enumerate(segmented_files, 1):
808
+ if os.path.exists(seg_file):
809
+ gallery_items.append(seg_file)
810
+
811
+ status_msg = f"✅ Hoàn thành! Đã tách được {len(gallery_items)} đối tượng"
812
+ return job_result['image'], status_msg, gallery_items
813
+ else:
814
+ return None, f"❌ Lỗi: {result.get('error', 'Unknown')}", []
815
 
816
  btn_submit_img.click(
817
  fn=submit_image_job,
 
820
  )
821
 
822
  btn_check_img.click(
823
+ fn=check_and_display_image,
824
  inputs=[job_id_img],
825
+ outputs=[image_result, status_img, segmented_gallery]
826
  )
827
 
828
  # ===== VIDEO SEGMENTATION TAB =====
 
841
  job_id_vid = gr.Textbox(label="Job ID", visible=False)
842
 
843
  with gr.Column():
844
+ gr.Markdown("### 📹 Video Outputs")
845
+
846
+ with gr.Tabs():
847
+ with gr.Tab("Overlay"):
848
+ video_result_overlay = gr.Video(label="1. Overlay (Original + Masks)")
849
+
850
+ with gr.Tab("Masks Only"):
851
+ video_result_masks = gr.Video(label="2. Masks Only (White on Black)")
852
+
853
+ with gr.Tab("Segmented"):
854
+ video_result_segmented = gr.Video(label="3. Segmented (Green Screen Background)")
855
+
856
  status_vid = gr.Textbox(label="Status", interactive=False)
857
+
858
+ def check_and_display_video(job_id):
859
+ """Check status and display all video outputs"""
860
+ if not job_id or job_id not in processing_results:
861
+ return None, None, None, "Không tìm thấy công việc"
862
+
863
+ result = processing_results[job_id]
864
+
865
+ if result['status'] == 'processing':
866
+ status = f"⏳ Đang xử lý... {result['progress']}%"
867
+ return None, None, None, status
868
+ elif result['status'] == 'completed':
869
+ job_result = result['result']
870
+ overlay = job_result.get('output_path')
871
+ masks = job_result.get('mask_video_path')
872
+ segmented = job_result.get('segmented_video_path')
873
+
874
+ status = "✅ Hoàn thành! 3 video đã được tạo:\n"
875
+ status += "1️⃣ Overlay - Ảnh gốc với mask màu\n"
876
+ status += "2️⃣ Masks Only - Chỉ mask (trắng/đen)\n"
877
+ status += "3️⃣ Segmented - Đối tượng với green screen"
878
+
879
+ return overlay, masks, segmented, status
880
+ else:
881
+ error_msg = f"❌ Lỗi: {result.get('error', 'Unknown')}"
882
+ return None, None, None, error_msg
883
 
884
  btn_submit_vid.click(
885
  fn=submit_video_job,
886
  inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
887
+ outputs=[video_result_overlay, status_vid, job_id_vid]
888
  )
889
 
890
  btn_check_vid.click(
891
+ fn=check_and_display_video,
892
  inputs=[job_id_vid],
893
+ outputs=[video_result_overlay, video_result_masks, video_result_segmented, status_vid]
894
  )
895
 
896
  # ===== CLICK SEGMENTATION TAB =====
 
920
  outputs=[img_click_output, st_click_points, st_click_labels]
921
  )
922
 
923
+ # ===== ADVANCED HISTORY TAB =====
924
+ with gr.Tab("📊 Lịch Sử & Thống Kê"):
925
+ with gr.Row():
926
+ # Statistics Dashboard
927
+ with gr.Column(scale=1):
928
+ gr.Markdown("### 📈 Thống Kê Tổng Quan")
929
+
930
+ def update_stats():
931
+ stats = get_history_stats()
932
+ return (
933
+ f"**{stats['total']}** Tổng số",
934
+ f"**{stats['completed']}** Hoàn thành",
935
+ f"**{stats['errors']}** Lỗi",
936
+ f"**{stats['success_rate']}** Tỷ lệ thành công"
937
+ )
938
+
939
+ with gr.Row():
940
+ stat_total = gr.Markdown("**0** Tổng số")
941
+ stat_completed = gr.Markdown("**0** Hoàn thành")
942
+ with gr.Row():
943
+ stat_errors = gr.Markdown("**0** Lỗi")
944
+ stat_success = gr.Markdown("**0%** Tỷ lệ thành công")
945
+
946
+ gr.Markdown("### 🎯 Thao Tác Nhanh")
947
+ with gr.Row():
948
+ btn_refresh = gr.Button("🔄 Refresh", variant="primary", scale=1)
949
+ btn_export = gr.Button("📥 Export JSON", variant="secondary", scale=1)
950
+ with gr.Row():
951
+ btn_clear_all = gr.Button("🗑️ Clear All History", variant="stop", scale=1)
952
+
953
+ export_file = gr.File(label="Exported File", visible=False)
954
+ clear_status = gr.Textbox(label="Status", interactive=False)
955
+
956
+ # History Table
957
  with gr.Row():
958
  with gr.Column():
959
+ gr.Markdown("### 📜 Lịch Sử Chi Tiết")
960
+
961
+ # Search and Filter
962
+ with gr.Row():
963
+ search_input = gr.Textbox(
964
+ placeholder="🔍 Tìm kiếm theo prompt...",
965
+ label="Search",
966
+ scale=2
967
+ )
968
+ filter_type = gr.Dropdown(
969
+ choices=["all", "image", "video", "click"],
970
+ value="all",
971
+ label="Loại",
972
+ scale=1
973
+ )
974
+ filter_status = gr.Dropdown(
975
+ choices=["all", "completed", "error"],
976
+ value="all",
977
+ label="Trạng thái",
978
+ scale=1
979
+ )
980
 
981
+ history_table = gr.HTML(value=format_history_table())
982
+
983
+ # Gallery View
984
+ with gr.Row():
985
+ with gr.Column():
986
+ gr.Markdown("### 🖼️ Gallery - Kết Quả Gần Đây")
987
+ history_gallery = gr.Gallery(
988
+ value=get_history_gallery(),
989
+ label="Recent Outputs",
990
+ columns=4,
991
+ height=400,
992
+ object_fit="contain"
993
+ )
994
+
995
+ # Event handlers
996
+ def refresh_all():
997
+ return (
998
+ *update_stats(),
999
+ format_history_table(),
1000
+ get_history_gallery()
1001
+ )
1002
 
1003
+ btn_refresh.click(
1004
+ fn=refresh_all,
1005
+ outputs=[stat_total, stat_completed, stat_errors, stat_success, history_table, history_gallery]
1006
+ )
1007
+
1008
+ btn_export.click(
1009
+ fn=export_history_json,
1010
+ outputs=[export_file]
1011
+ )
1012
+
1013
+ btn_clear_all.click(
1014
+ fn=clear_all_history,
1015
+ outputs=[clear_status]
1016
+ ).then(
1017
+ fn=refresh_all,
1018
+ outputs=[stat_total, stat_completed, stat_errors, stat_success, history_table, history_gallery]
1019
+ )
1020
+
1021
+ # Auto-refresh when searching/filtering
1022
+ def filter_and_display(keyword, ftype, fstatus):
1023
+ filtered = search_history(keyword, ftype, fstatus)
1024
+ # Format filtered results
1025
+ if not filtered:
1026
+ return "<p style='text-align:center; color:#666;'>Không tìm thấy kết quả</p>"
1027
+
1028
+ # Reuse formatting logic
1029
+ html = format_history_table()
1030
+ return html
1031
+
1032
+ search_input.change(
1033
+ fn=filter_and_display,
1034
+ inputs=[search_input, filter_type, filter_status],
1035
+ outputs=[history_table]
1036
+ )
1037
+ filter_type.change(
1038
+ fn=filter_and_display,
1039
+ inputs=[search_input, filter_type, filter_status],
1040
+ outputs=[history_table]
1041
+ )
1042
+ filter_status.change(
1043
+ fn=filter_and_display,
1044
+ inputs=[search_input, filter_type, filter_status],
1045
+ outputs=[history_table]
1046
  )
 
 
 
 
 
 
 
 
 
 
1047
 
1048
  if __name__ == "__main__":
1049
  demo.launch(