Upload app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
import os
|
| 2 |
import cv2
|
|
|
|
| 3 |
import spaces
|
| 4 |
import gradio as gr
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
import matplotlib
|
|
|
|
| 8 |
from PIL import Image, ImageDraw
|
| 9 |
from typing import Iterable
|
| 10 |
from gradio.themes import Soft
|
|
@@ -19,24 +21,46 @@ from datetime import datetime
|
|
| 19 |
import threading
|
| 20 |
import queue
|
| 21 |
import uuid
|
| 22 |
-
import shutil
|
| 23 |
-
import zipfile
|
| 24 |
|
| 25 |
# ============ THEME SETUP ============
|
| 26 |
colors.steel_blue = colors.Color(
|
| 27 |
name="steel_blue",
|
| 28 |
-
c50="#EBF3F8",
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
)
|
| 32 |
|
| 33 |
class CustomBlueTheme(Soft):
|
| 34 |
-
def __init__(
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
super().set(
|
| 41 |
background_fill_primary="*primary_50",
|
| 42 |
background_fill_primary_dark="*primary_900",
|
|
@@ -62,255 +86,76 @@ class CustomBlueTheme(Soft):
|
|
| 62 |
app_theme = CustomBlueTheme()
|
| 63 |
|
| 64 |
# ============ GLOBAL SETUP ============
|
| 65 |
-
device = "
|
| 66 |
-
print(f"🖥️ Using device: {device}
|
| 67 |
|
|
|
|
| 68 |
HISTORY_DIR = "processing_history"
|
| 69 |
-
|
| 70 |
-
DOWNLOADS_DIR = os.path.join(HISTORY_DIR, "downloads")
|
| 71 |
-
os.makedirs(OUTPUTS_DIR, exist_ok=True)
|
| 72 |
-
os.makedirs(DOWNLOADS_DIR, exist_ok=True)
|
| 73 |
HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
|
| 74 |
|
|
|
|
| 75 |
processing_queue = queue.Queue()
|
| 76 |
processing_results = {}
|
| 77 |
|
| 78 |
# Load models
|
| 79 |
-
print("⏳ Loading SAM3 Models...")
|
| 80 |
try:
|
| 81 |
-
print(" Loading Image Model
|
| 82 |
IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3").to(device)
|
| 83 |
IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
|
| 84 |
-
|
| 85 |
-
print(" Loading Tracker Model
|
| 86 |
TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3").to(device)
|
| 87 |
TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
|
| 88 |
-
|
| 89 |
-
print(" Loading Video Model
|
| 90 |
-
VID_MODEL = Sam3VideoModel.from_pretrained("DiffusionWave/sam3").to(device
|
| 91 |
VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("DiffusionWave/sam3")
|
| 92 |
|
| 93 |
-
print("✅ All
|
| 94 |
except Exception as e:
|
| 95 |
-
print(f"❌
|
| 96 |
IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = VID_MODEL = VID_PROCESSOR = None
|
| 97 |
|
| 98 |
# ============ HISTORY MANAGEMENT ============
|
| 99 |
def load_history():
|
|
|
|
| 100 |
if os.path.exists(HISTORY_FILE):
|
| 101 |
try:
|
| 102 |
-
with open(HISTORY_FILE, 'r'
|
| 103 |
return json.load(f)
|
| 104 |
except:
|
| 105 |
return []
|
| 106 |
return []
|
| 107 |
|
| 108 |
-
def save_history(
|
| 109 |
-
history
|
| 110 |
-
history.insert(0, item)
|
| 111 |
-
history = history[:200]
|
| 112 |
-
with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
|
| 113 |
-
json.dump(history, f, indent=2, ensure_ascii=False)
|
| 114 |
-
|
| 115 |
-
def get_history_stats():
|
| 116 |
-
history = load_history()
|
| 117 |
-
total = len(history)
|
| 118 |
-
completed = sum(1 for h in history if h['status'] == 'completed')
|
| 119 |
-
errors = sum(1 for h in history if h['status'] == 'error')
|
| 120 |
-
types = {}
|
| 121 |
-
for h in history:
|
| 122 |
-
t = h['type']
|
| 123 |
-
types[t] = types.get(t, 0) + 1
|
| 124 |
-
return {
|
| 125 |
-
'total': total,
|
| 126 |
-
'completed': completed,
|
| 127 |
-
'errors': errors,
|
| 128 |
-
'success_rate': f"{(completed/total*100):.1f}%" if total > 0 else "0%",
|
| 129 |
-
'types': types
|
| 130 |
-
}
|
| 131 |
-
|
| 132 |
-
def create_download_package(item_id):
|
| 133 |
-
history = load_history()
|
| 134 |
-
item = next((h for h in history if h['id'] == item_id), None)
|
| 135 |
-
|
| 136 |
-
if not item or item['status'] != 'completed':
|
| 137 |
-
return None
|
| 138 |
-
|
| 139 |
-
zip_path = os.path.join(DOWNLOADS_DIR, f"{item_id}_results.zip")
|
| 140 |
-
|
| 141 |
-
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
| 142 |
-
metadata = {
|
| 143 |
-
'job_id': item_id,
|
| 144 |
-
'type': item['type'],
|
| 145 |
-
'prompt': item.get('prompt', 'N/A'),
|
| 146 |
-
'timestamp': item['timestamp'],
|
| 147 |
-
'duration': item.get('duration', 'N/A'),
|
| 148 |
-
'num_objects': item.get('num_objects', 0)
|
| 149 |
-
}
|
| 150 |
-
zipf.writestr('metadata.json', json.dumps(metadata, indent=2, ensure_ascii=False))
|
| 151 |
-
|
| 152 |
-
if item['type'] == 'image':
|
| 153 |
-
if item.get('output_path') and os.path.exists(item['output_path']):
|
| 154 |
-
zipf.write(item['output_path'], 'overlay.jpg')
|
| 155 |
-
if item.get('segmented_files'):
|
| 156 |
-
for i, f in enumerate(item['segmented_files'], 1):
|
| 157 |
-
if os.path.exists(f):
|
| 158 |
-
zipf.write(f, f'objects/object_{i}.png')
|
| 159 |
-
|
| 160 |
-
elif item['type'] == 'video':
|
| 161 |
-
if item.get('output_path') and os.path.exists(item['output_path']):
|
| 162 |
-
zipf.write(item['output_path'], 'overlay_video.mp4')
|
| 163 |
-
if item.get('mask_video_path') and os.path.exists(item['mask_video_path']):
|
| 164 |
-
zipf.write(item['mask_video_path'], 'masks_only.mp4')
|
| 165 |
-
if item.get('segmented_video_path') and os.path.exists(item['segmented_video_path']):
|
| 166 |
-
zipf.write(item['segmented_video_path'], 'segmented_video.mp4')
|
| 167 |
-
|
| 168 |
-
elif item['type'] == 'click':
|
| 169 |
-
if item.get('output_path') and os.path.exists(item['output_path']):
|
| 170 |
-
zipf.write(item['output_path'], 'result.jpg')
|
| 171 |
-
|
| 172 |
-
return zip_path
|
| 173 |
-
|
| 174 |
-
def get_downloadable_jobs():
|
| 175 |
history = load_history()
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
label = f"{type_emoji} [{item['type'].upper()}] {item['prompt'][:35]}... | {item['timestamp']}"
|
| 181 |
-
choices.append((label, item['id']))
|
| 182 |
-
return choices if choices else [("No completed jobs available", None)]
|
| 183 |
|
| 184 |
-
def
|
|
|
|
| 185 |
history = load_history()
|
| 186 |
if not history:
|
| 187 |
-
return "
|
| 188 |
-
|
| 189 |
-
html = """
|
| 190 |
-
<style>
|
| 191 |
-
.history-table { width: 100%; border-collapse: collapse; font-size: 14px; background: white; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
| 192 |
-
.history-table th { background: linear-gradient(90deg, #4682B4, #529AC3); color: white; padding: 14px 12px; text-align: left; font-weight: 600; text-transform: uppercase; font-size: 12px; letter-spacing: 0.5px; }
|
| 193 |
-
.history-table td { padding: 12px; border-bottom: 1px solid #e8e8e8; vertical-align: middle; }
|
| 194 |
-
.history-table tr:hover { background-color: #f8f9fa; }
|
| 195 |
-
.history-table tr:last-child td { border-bottom: none; }
|
| 196 |
-
.status-badge { padding: 5px 12px; border-radius: 14px; font-size: 11px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.5px; display: inline-block; }
|
| 197 |
-
.status-completed { background: linear-gradient(135deg, #d4edda, #c3e6cb); color: #155724; }
|
| 198 |
-
.status-error { background: linear-gradient(135deg, #f8d7da, #f5c6cb); color: #721c24; }
|
| 199 |
-
.type-badge { padding: 5px 10px; border-radius: 10px; font-size: 11px; font-weight: 600; background: linear-gradient(135deg, #e3f2fd, #bbdefb); color: #1565c0; display: inline-block; }
|
| 200 |
-
.prompt-text { max-width: 280px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; color: #333; font-weight: 500; }
|
| 201 |
-
.file-count { font-size: 11px; color: #666; margin-top: 4px; line-height: 1.4; }
|
| 202 |
-
.job-id { font-family: 'Courier New', monospace; font-size: 10px; color: #999; background: #f5f5f5; padding: 3px 6px; border-radius: 4px; }
|
| 203 |
-
.time-info { font-size: 12px; color: #666; }
|
| 204 |
-
.duration { font-size: 11px; color: #999; margin-top: 3px; }
|
| 205 |
-
</style>
|
| 206 |
-
<table class='history-table'>
|
| 207 |
-
<thead>
|
| 208 |
-
<tr>
|
| 209 |
-
<th style='width: 40px; text-align: center;'>#</th>
|
| 210 |
-
<th style='width: 110px;'>Job ID</th>
|
| 211 |
-
<th style='width: 90px;'>Type</th>
|
| 212 |
-
<th style='width: 110px;'>Status</th>
|
| 213 |
-
<th>Prompt</th>
|
| 214 |
-
<th style='width: 120px;'>Output Files</th>
|
| 215 |
-
<th style='width: 140px;'>Time</th>
|
| 216 |
-
</tr>
|
| 217 |
-
</thead>
|
| 218 |
-
<tbody>
|
| 219 |
-
"""
|
| 220 |
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
type_icon = type_icons.get(item['type'], '📄')
|
| 227 |
-
|
| 228 |
-
prompt = item.get('prompt', 'N/A')
|
| 229 |
-
prompt_short = prompt[:45] + ('...' if len(prompt) > 45 else '')
|
| 230 |
-
|
| 231 |
-
file_info = []
|
| 232 |
if item.get('output_path'):
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
if item.get('mask_video_path'):
|
| 237 |
-
file_info.append("✓ Masks")
|
| 238 |
-
if item.get('segmented_video_path'):
|
| 239 |
-
file_info.append("✓ Segmented")
|
| 240 |
-
|
| 241 |
-
files_text = "<br>".join(file_info) if file_info else "No files"
|
| 242 |
-
|
| 243 |
-
html += f"""
|
| 244 |
-
<tr>
|
| 245 |
-
<td style='text-align: center; font-weight: 600; color: #999;'>{i}</td>
|
| 246 |
-
<td><span class='job-id'>{item['id'][:12]}</span></td>
|
| 247 |
-
<td><span class='type-badge'>{type_icon} {item['type'].upper()}</span></td>
|
| 248 |
-
<td><span class='status-badge {status_class}'>{status_text}</span></td>
|
| 249 |
-
<td class='prompt-text' title='{prompt}'>{prompt_short}</td>
|
| 250 |
-
<td><div class='file-count'>{files_text}</div></td>
|
| 251 |
-
<td>
|
| 252 |
-
<div class='time-info'>{item['timestamp']}</div>
|
| 253 |
-
<div class='duration'>⏱️ {item.get('duration', 'N/A')}</div>
|
| 254 |
-
</td>
|
| 255 |
-
</tr>
|
| 256 |
-
"""
|
| 257 |
-
|
| 258 |
-
html += """
|
| 259 |
-
</tbody>
|
| 260 |
-
</table>
|
| 261 |
-
"""
|
| 262 |
-
|
| 263 |
-
return html
|
| 264 |
|
| 265 |
-
|
| 266 |
-
history = load_history()
|
| 267 |
-
gallery_items = []
|
| 268 |
-
|
| 269 |
-
for item in history[:30]:
|
| 270 |
-
if item['status'] == 'completed':
|
| 271 |
-
if item.get('output_path') and os.path.exists(item['output_path']):
|
| 272 |
-
caption = f"[{item['type'].upper()}] {item['prompt'][:35]}... | {item['timestamp']}"
|
| 273 |
-
gallery_items.append((item['output_path'], caption))
|
| 274 |
-
|
| 275 |
-
return gallery_items if gallery_items else []
|
| 276 |
-
|
| 277 |
-
def search_history(keyword, filter_type, filter_status):
|
| 278 |
-
history = load_history()
|
| 279 |
-
filtered = history
|
| 280 |
-
|
| 281 |
-
if keyword:
|
| 282 |
-
filtered = [h for h in filtered if keyword.lower() in h.get('prompt', '').lower()]
|
| 283 |
-
if filter_type and filter_type != "all":
|
| 284 |
-
filtered = [h for h in filtered if h['type'] == filter_type]
|
| 285 |
-
if filter_status and filter_status != "all":
|
| 286 |
-
filtered = [h for h in filtered if h['status'] == filter_status]
|
| 287 |
-
|
| 288 |
-
return filtered
|
| 289 |
-
|
| 290 |
-
def clear_all_history():
|
| 291 |
-
if os.path.exists(OUTPUTS_DIR):
|
| 292 |
-
shutil.rmtree(OUTPUTS_DIR)
|
| 293 |
-
os.makedirs(OUTPUTS_DIR)
|
| 294 |
-
if os.path.exists(DOWNLOADS_DIR):
|
| 295 |
-
shutil.rmtree(DOWNLOADS_DIR)
|
| 296 |
-
os.makedirs(DOWNLOADS_DIR)
|
| 297 |
-
|
| 298 |
-
with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
|
| 299 |
-
json.dump([], f)
|
| 300 |
-
|
| 301 |
-
return "✅ Đã xóa toàn bộ lịch sử và files"
|
| 302 |
-
|
| 303 |
-
def export_history_json():
|
| 304 |
-
history = load_history()
|
| 305 |
-
export_path = os.path.join(HISTORY_DIR, f"history_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
|
| 306 |
-
|
| 307 |
-
with open(export_path, 'w', encoding='utf-8') as f:
|
| 308 |
-
json.dump(history, f, indent=2, ensure_ascii=False)
|
| 309 |
-
|
| 310 |
-
return export_path
|
| 311 |
-
|
| 312 |
-
# ============ PROCESSING UTILS ============
|
| 313 |
def apply_mask_overlay(base_image, mask_data, opacity=0.5):
|
|
|
|
| 314 |
if isinstance(base_image, np.ndarray):
|
| 315 |
base_image = Image.fromarray(base_image)
|
| 316 |
base_image = base_image.convert("RGBA")
|
|
@@ -322,10 +167,8 @@ def apply_mask_overlay(base_image, mask_data, opacity=0.5):
|
|
| 322 |
mask_data = mask_data.cpu().numpy()
|
| 323 |
mask_data = mask_data.astype(np.uint8)
|
| 324 |
|
| 325 |
-
if mask_data.ndim == 4:
|
| 326 |
-
|
| 327 |
-
if mask_data.ndim == 3 and mask_data.shape[0] == 1:
|
| 328 |
-
mask_data = mask_data[0]
|
| 329 |
|
| 330 |
num_masks = mask_data.shape[0] if mask_data.ndim == 3 else 1
|
| 331 |
if mask_data.ndim == 2:
|
|
@@ -334,262 +177,44 @@ def apply_mask_overlay(base_image, mask_data, opacity=0.5):
|
|
| 334 |
|
| 335 |
try:
|
| 336 |
color_map = matplotlib.colormaps["rainbow"].resampled(max(num_masks, 1))
|
| 337 |
-
except:
|
| 338 |
import matplotlib.cm as cm
|
| 339 |
color_map = cm.get_cmap("rainbow").resampled(max(num_masks, 1))
|
| 340 |
|
| 341 |
rgb_colors = [tuple(int(c * 255) for c in color_map(i)[:3]) for i in range(num_masks)]
|
| 342 |
composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
|
| 343 |
|
| 344 |
-
for i,
|
| 345 |
-
|
| 346 |
-
if
|
| 347 |
-
|
| 348 |
|
| 349 |
-
|
| 350 |
-
|
|
|
|
| 351 |
color_fill.putalpha(mask_alpha)
|
| 352 |
composite_layer = Image.alpha_composite(composite_layer, color_fill)
|
| 353 |
|
| 354 |
return Image.alpha_composite(base_image, composite_layer).convert("RGB")
|
| 355 |
|
| 356 |
def draw_points_on_image(image, points):
|
|
|
|
| 357 |
if isinstance(image, np.ndarray):
|
| 358 |
image = Image.fromarray(image)
|
|
|
|
| 359 |
draw_img = image.copy()
|
| 360 |
draw = ImageDraw.Draw(draw_img)
|
| 361 |
-
|
|
|
|
|
|
|
| 362 |
r = 8
|
| 363 |
draw.ellipse((x-r, y-r, x+r, y+r), fill="red", outline="white", width=4)
|
| 364 |
-
return draw_img
|
| 365 |
-
|
| 366 |
-
# ============ JOB PROCESSORS ============
|
| 367 |
-
def process_image_job(job):
|
| 368 |
-
start = datetime.now()
|
| 369 |
-
img = job['image']
|
| 370 |
-
if isinstance(img, str):
|
| 371 |
-
img = Image.open(img)
|
| 372 |
-
|
| 373 |
-
img = img.convert("RGB")
|
| 374 |
-
inputs = IMG_PROCESSOR(images=img, text=job['prompt'], return_tensors="pt").to(device)
|
| 375 |
-
|
| 376 |
-
with torch.no_grad():
|
| 377 |
-
outputs = IMG_MODEL(**inputs)
|
| 378 |
-
|
| 379 |
-
results = IMG_PROCESSOR.post_process_instance_segmentation(
|
| 380 |
-
outputs,
|
| 381 |
-
threshold=job.get('conf_thresh', 0.5),
|
| 382 |
-
mask_threshold=0.5,
|
| 383 |
-
target_sizes=inputs.get("original_sizes").tolist()
|
| 384 |
-
)[0]
|
| 385 |
-
|
| 386 |
-
masks = results['masks'].cpu().numpy()
|
| 387 |
-
scores = results['scores'].cpu().numpy()
|
| 388 |
-
annotations = [(m, f"{job['prompt']} ({s:.2f})") for m, s in zip(masks, scores)]
|
| 389 |
-
|
| 390 |
-
out_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.jpg")
|
| 391 |
-
apply_mask_overlay(img, masks).save(out_path)
|
| 392 |
-
|
| 393 |
-
seg_files = []
|
| 394 |
-
for i, mask in enumerate(masks):
|
| 395 |
-
mask_bool = mask.astype(bool)
|
| 396 |
-
seg = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
| 397 |
-
arr = np.array(img.convert("RGBA"))
|
| 398 |
-
arr[~mask_bool] = [0, 0, 0, 0]
|
| 399 |
-
seg = Image.fromarray(arr)
|
| 400 |
-
|
| 401 |
-
# Fix: Convert mask to uint8 before creating Image
|
| 402 |
-
mask_uint8 = (mask * 255).astype(np.uint8)
|
| 403 |
-
bbox = Image.fromarray(mask_uint8).getbbox()
|
| 404 |
-
|
| 405 |
-
if bbox:
|
| 406 |
-
seg_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_obj_{i+1}.png")
|
| 407 |
-
seg.crop(bbox).save(seg_path)
|
| 408 |
-
seg_files.append(seg_path)
|
| 409 |
-
|
| 410 |
-
return {
|
| 411 |
-
'image': (img, annotations),
|
| 412 |
-
'output_path': out_path,
|
| 413 |
-
'segmented_files': seg_files,
|
| 414 |
-
'num_objects': len(seg_files),
|
| 415 |
-
'duration': f"{(datetime.now() - start).total_seconds():.2f}s"
|
| 416 |
-
}
|
| 417 |
-
|
| 418 |
-
def process_video_job(job):
|
| 419 |
-
"""Process video on CPU - slower but no timeout"""
|
| 420 |
-
start = datetime.now()
|
| 421 |
-
cap = cv2.VideoCapture(job['video'])
|
| 422 |
-
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 423 |
-
w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 424 |
-
|
| 425 |
-
frames = []
|
| 426 |
-
limit = job.get('frame_limit', 60)
|
| 427 |
-
if limit == 0 or limit > 500:
|
| 428 |
-
limit = 500 # Higher limit for CPU since no GPU timeout
|
| 429 |
-
|
| 430 |
-
count = 0
|
| 431 |
-
while cap.isOpened():
|
| 432 |
-
ret, frame = cap.read()
|
| 433 |
-
if not ret or count >= limit:
|
| 434 |
-
break
|
| 435 |
-
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 436 |
-
count += 1
|
| 437 |
-
cap.release()
|
| 438 |
-
|
| 439 |
-
print(f"📹 Processing {len(frames)} frames on CPU (this will take longer)...")
|
| 440 |
-
|
| 441 |
-
# Process in chunks to manage memory
|
| 442 |
-
chunk_size = 30 # Smaller chunks for CPU
|
| 443 |
-
|
| 444 |
-
out_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_overlay.mp4")
|
| 445 |
-
mask_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_masks.mp4")
|
| 446 |
-
seg_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_segmented.mp4")
|
| 447 |
-
|
| 448 |
-
writers = [
|
| 449 |
-
cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)),
|
| 450 |
-
cv2.VideoWriter(mask_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)),
|
| 451 |
-
cv2.VideoWriter(seg_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
| 452 |
-
]
|
| 453 |
-
|
| 454 |
-
total = len(frames)
|
| 455 |
-
processed = 0
|
| 456 |
-
|
| 457 |
-
# Process frames in chunks
|
| 458 |
-
for chunk_start in range(0, total, chunk_size):
|
| 459 |
-
chunk_end = min(chunk_start + chunk_size, total)
|
| 460 |
-
chunk_frames = frames[chunk_start:chunk_end]
|
| 461 |
-
|
| 462 |
-
print(f"🔄 Processing chunk {chunk_start}-{chunk_end} ({len(chunk_frames)} frames)")
|
| 463 |
-
|
| 464 |
-
try:
|
| 465 |
-
# Initialize session for this chunk
|
| 466 |
-
session = VID_PROCESSOR.init_video_session(
|
| 467 |
-
video=chunk_frames,
|
| 468 |
-
inference_device=device
|
| 469 |
-
)
|
| 470 |
-
session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=job['prompt'])
|
| 471 |
-
|
| 472 |
-
# Process chunk
|
| 473 |
-
for idx, out in enumerate(VID_MODEL.propagate_in_video_iterator(
|
| 474 |
-
inference_session=session,
|
| 475 |
-
max_frame_num_to_track=len(chunk_frames)
|
| 476 |
-
)):
|
| 477 |
-
try:
|
| 478 |
-
proc = VID_PROCESSOR.postprocess_outputs(session, out)
|
| 479 |
-
f_idx = out.frame_idx
|
| 480 |
-
orig = Image.fromarray(chunk_frames[f_idx])
|
| 481 |
-
|
| 482 |
-
if 'masks' in proc:
|
| 483 |
-
masks = proc['masks']
|
| 484 |
-
if masks.ndim == 4:
|
| 485 |
-
masks = masks.squeeze(1)
|
| 486 |
-
|
| 487 |
-
overlay = apply_mask_overlay(orig, masks)
|
| 488 |
-
writers[0].write(cv2.cvtColor(np.array(overlay), cv2.COLOR_RGB2BGR))
|
| 489 |
-
|
| 490 |
-
mask_np = masks.cpu().numpy() if isinstance(masks, torch.Tensor) else masks
|
| 491 |
-
combined = np.zeros((h, w), dtype=np.uint8)
|
| 492 |
-
for m in mask_np:
|
| 493 |
-
if m.shape != (h, w):
|
| 494 |
-
m = cv2.resize(m.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST)
|
| 495 |
-
combined = np.maximum(combined, m)
|
| 496 |
-
|
| 497 |
-
mask_frame = np.zeros((h, w, 3), dtype=np.uint8)
|
| 498 |
-
mask_frame[combined > 0] = [255, 255, 255]
|
| 499 |
-
writers[1].write(mask_frame)
|
| 500 |
-
|
| 501 |
-
seg_arr = np.array(orig.convert("RGBA"))
|
| 502 |
-
seg_arr[:, :, 3] = (combined * 255).astype(np.uint8)
|
| 503 |
-
bgr = np.zeros((h, w, 3), dtype=np.uint8)
|
| 504 |
-
bgr[:, :] = [0, 255, 0]
|
| 505 |
-
for c in range(3):
|
| 506 |
-
bgr[:, :, c] = np.where(combined > 0, seg_arr[:, :, 2-c], bgr[:, :, c])
|
| 507 |
-
writers[2].write(bgr)
|
| 508 |
-
else:
|
| 509 |
-
orig_bgr = cv2.cvtColor(np.array(orig), cv2.COLOR_RGB2BGR)
|
| 510 |
-
writers[0].write(orig_bgr)
|
| 511 |
-
writers[1].write(np.zeros((h, w, 3), dtype=np.uint8))
|
| 512 |
-
writers[2].write(orig_bgr)
|
| 513 |
-
|
| 514 |
-
processed += 1
|
| 515 |
-
progress = int((processed / total) * 100)
|
| 516 |
-
processing_results[job['id']]['progress'] = progress
|
| 517 |
-
|
| 518 |
-
if processed % 5 == 0:
|
| 519 |
-
elapsed = (datetime.now() - start).total_seconds()
|
| 520 |
-
avg_time = elapsed / processed
|
| 521 |
-
remaining = (total - processed) * avg_time
|
| 522 |
-
print(f"⏳ Progress: {progress}% ({processed}/{total}) | ETA: {remaining/60:.1f} min")
|
| 523 |
-
|
| 524 |
-
except Exception as e:
|
| 525 |
-
print(f"⚠️ Error processing frame {f_idx}: {e}")
|
| 526 |
-
orig_bgr = cv2.cvtColor(np.array(orig), cv2.COLOR_RGB2BGR)
|
| 527 |
-
writers[0].write(orig_bgr)
|
| 528 |
-
writers[1].write(np.zeros((h, w, 3), dtype=np.uint8))
|
| 529 |
-
writers[2].write(orig_bgr)
|
| 530 |
-
processed += 1
|
| 531 |
-
|
| 532 |
-
# Clear memory after each chunk
|
| 533 |
-
del session
|
| 534 |
-
|
| 535 |
-
except Exception as e:
|
| 536 |
-
print(f"❌ Error processing chunk: {e}")
|
| 537 |
-
for i in range(chunk_start, chunk_end):
|
| 538 |
-
if i < len(frames):
|
| 539 |
-
orig_bgr = cv2.cvtColor(frames[i], cv2.COLOR_RGB2BGR)
|
| 540 |
-
writers[0].write(orig_bgr)
|
| 541 |
-
writers[1].write(np.zeros((h, w, 3), dtype=np.uint8))
|
| 542 |
-
writers[2].write(orig_bgr)
|
| 543 |
-
processed += 1
|
| 544 |
-
|
| 545 |
-
for w in writers:
|
| 546 |
-
w.release()
|
| 547 |
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
return {
|
| 551 |
-
'output_path': out_path,
|
| 552 |
-
'mask_video_path': mask_path,
|
| 553 |
-
'segmented_video_path': seg_path,
|
| 554 |
-
'duration': f"{(datetime.now() - start).total_seconds():.2f}s"
|
| 555 |
-
}
|
| 556 |
-
|
| 557 |
-
def process_click_job(job):
|
| 558 |
-
start = datetime.now()
|
| 559 |
-
img = job['image']
|
| 560 |
-
if isinstance(img, str):
|
| 561 |
-
img = Image.open(img)
|
| 562 |
-
|
| 563 |
-
inputs = TRK_PROCESSOR(
|
| 564 |
-
images=img,
|
| 565 |
-
input_points=[[job['points']]],
|
| 566 |
-
input_labels=[[job['labels']]],
|
| 567 |
-
return_tensors="pt"
|
| 568 |
-
).to(device)
|
| 569 |
-
|
| 570 |
-
with torch.no_grad():
|
| 571 |
-
outputs = TRK_MODEL(**inputs, multimask_output=False)
|
| 572 |
-
|
| 573 |
-
masks = TRK_PROCESSOR.post_process_masks(
|
| 574 |
-
outputs.pred_masks.cpu(),
|
| 575 |
-
inputs["original_sizes"],
|
| 576 |
-
binarize=True
|
| 577 |
-
)[0]
|
| 578 |
-
|
| 579 |
-
result = apply_mask_overlay(img, masks[0])
|
| 580 |
-
result = draw_points_on_image(result, job['points'])
|
| 581 |
-
|
| 582 |
-
out_path = os.path.join(OUTPUTS_DIR, f"{job['id']}_result.jpg")
|
| 583 |
-
result.save(out_path)
|
| 584 |
-
|
| 585 |
-
return {
|
| 586 |
-
'image': result,
|
| 587 |
-
'output_path': out_path,
|
| 588 |
-
'duration': f"{(datetime.now() - start).total_seconds():.2f}s"
|
| 589 |
-
}
|
| 590 |
|
| 591 |
-
# ============ BACKGROUND WORKER ============
|
| 592 |
def background_worker():
|
|
|
|
| 593 |
while True:
|
| 594 |
try:
|
| 595 |
job = processing_queue.get()
|
|
@@ -599,8 +224,6 @@ def background_worker():
|
|
| 599 |
job_id = job['id']
|
| 600 |
job_type = job['type']
|
| 601 |
|
| 602 |
-
print(f"🚀 Starting job {job_id[:8]} - Type: {job_type}")
|
| 603 |
-
|
| 604 |
processing_results[job_id] = {'status': 'processing', 'progress': 0}
|
| 605 |
|
| 606 |
try:
|
|
@@ -617,22 +240,17 @@ def background_worker():
|
|
| 617 |
'progress': 100
|
| 618 |
}
|
| 619 |
|
| 620 |
-
|
| 621 |
-
|
| 622 |
save_history({
|
| 623 |
'id': job_id,
|
| 624 |
'type': job_type,
|
| 625 |
'prompt': job.get('prompt', 'N/A'),
|
| 626 |
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
| 627 |
'status': 'completed',
|
| 628 |
-
|
| 629 |
})
|
| 630 |
|
| 631 |
except Exception as e:
|
| 632 |
-
print(f"❌ Job {job_id[:8]} failed: {str(e)}")
|
| 633 |
-
import traceback
|
| 634 |
-
traceback.print_exc()
|
| 635 |
-
|
| 636 |
processing_results[job_id] = {
|
| 637 |
'status': 'error',
|
| 638 |
'error': str(e),
|
|
@@ -647,513 +265,360 @@ def background_worker():
|
|
| 647 |
'error': str(e)
|
| 648 |
})
|
| 649 |
except Exception as e:
|
| 650 |
-
print(f"
|
| 651 |
-
|
| 652 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
|
| 654 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
"""
|
| 664 |
|
| 665 |
-
with gr.Blocks(
|
| 666 |
with gr.Column(elem_id="col-container"):
|
| 667 |
gr.Markdown("# **SAM3: Segment Anything Model 3** 🚀", elem_id="main-title")
|
| 668 |
-
gr.Markdown("
|
| 669 |
-
|
| 670 |
-
gr.Markdown("""
|
| 671 |
-
<div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 10px; color: white; margin-bottom: 20px;'>
|
| 672 |
-
<strong>🔥 Đặc điểm CPU Mode:</strong><br>
|
| 673 |
-
✅ Không bị timeout - xử lý video dài thoải mái<br>
|
| 674 |
-
⏱️ Chậm hơn nhưng ổn định - tốc độ ~2-3 phút/frame<br>
|
| 675 |
-
🔋 Chạy background - submit job và làm việc khác<br>
|
| 676 |
-
💾 Tự động lưu lịch sử và download được
|
| 677 |
-
</div>
|
| 678 |
-
""")
|
| 679 |
-
|
| 680 |
with gr.Tabs():
|
| 681 |
-
# ===== IMAGE TAB =====
|
| 682 |
with gr.Tab("📷 Image Segmentation"):
|
| 683 |
with gr.Row():
|
| 684 |
with gr.Column(scale=1):
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
)
|
| 691 |
-
gr.
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
2. Click vào đối tượng bạn muốn phân đoạn
|
| 695 |
-
3. Kết quả hiển thị ngay lập tức
|
| 696 |
-
4. Click "Clear" để reset và bắt đầu lại
|
| 697 |
-
""")
|
| 698 |
-
|
| 699 |
-
click_clear = gr.Button("🔄 Clear Points & Reset", variant="primary")
|
| 700 |
-
|
| 701 |
-
click_pts = gr.State([])
|
| 702 |
-
click_lbl = gr.State([])
|
| 703 |
-
|
| 704 |
-
with gr.Column(scale=1):
|
| 705 |
-
img_input = gr.Image(label="📤 Upload Image", type="pil", height=350)
|
| 706 |
-
img_prompt = gr.Textbox(
|
| 707 |
-
label="✍️ Text Prompt",
|
| 708 |
-
placeholder="e.g., cat, person, car, building...",
|
| 709 |
-
lines=2
|
| 710 |
-
)
|
| 711 |
-
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
| 712 |
-
img_conf = gr.Slider(0.0, 1.0, 0.45, 0.05, label="Confidence Threshold")
|
| 713 |
-
|
| 714 |
-
img_submit = gr.Button("🚀 Submit Job (Background)", variant="primary", size="lg")
|
| 715 |
-
img_check = gr.Button("🔍 Check Status", variant="secondary")
|
| 716 |
-
img_job_id = gr.Textbox(label="Job ID", visible=False)
|
| 717 |
-
|
| 718 |
with gr.Column(scale=1.5):
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
with gr.Accordion("📦 Extracted Objects", open=True):
|
| 723 |
-
gr.Markdown("**Các đối tượng được tách ra (PNG với nền trong suốt):**")
|
| 724 |
-
img_gallery = gr.Gallery(
|
| 725 |
-
label="Segmented Objects",
|
| 726 |
-
columns=3,
|
| 727 |
-
height=300,
|
| 728 |
-
object_fit="contain"
|
| 729 |
-
)
|
| 730 |
-
|
| 731 |
-
def submit_img(img, prompt, conf):
|
| 732 |
-
if not img or not prompt:
|
| 733 |
-
return None, "❌ Vui lòng cung cấp ảnh và prompt", "", []
|
| 734 |
-
jid = str(uuid.uuid4())
|
| 735 |
-
processing_queue.put({
|
| 736 |
-
'id': jid,
|
| 737 |
-
'type': 'image',
|
| 738 |
-
'image': img,
|
| 739 |
-
'prompt': prompt,
|
| 740 |
-
'conf_thresh': conf
|
| 741 |
-
})
|
| 742 |
-
return None, f"✅ Đã thêm vào hàng chờ (ID: {jid[:8]}). Đang xử lý...", jid, []
|
| 743 |
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
r = processing_results[jid]
|
| 749 |
-
|
| 750 |
-
if r['status'] == 'processing':
|
| 751 |
-
return None, f"⏳ Đang xử lý... {r['progress']}%", []
|
| 752 |
-
elif r['status'] == 'completed':
|
| 753 |
-
res = r['result']
|
| 754 |
-
gal = [f for f in res.get('segmented_files', []) if os.path.exists(f)]
|
| 755 |
-
status = f"✅ Hoàn thành! Đã tách được {len(gal)} đối tượng | Thời gian: {res.get('duration', 'N/A')}"
|
| 756 |
-
return res['image'], status, gal
|
| 757 |
-
else:
|
| 758 |
-
return None, f"❌ Lỗi: {r.get('error', 'Unknown')}", []
|
| 759 |
-
|
| 760 |
-
img_submit.click(
|
| 761 |
-
fn=submit_img,
|
| 762 |
-
inputs=[img_input, img_prompt, img_conf],
|
| 763 |
-
outputs=[img_result, img_status, img_job_id, img_gallery]
|
| 764 |
)
|
| 765 |
|
| 766 |
-
|
| 767 |
-
fn=
|
| 768 |
-
inputs=[
|
| 769 |
-
outputs=[
|
| 770 |
)
|
| 771 |
|
| 772 |
-
# ===== VIDEO TAB =====
|
| 773 |
with gr.Tab("🎥 Video Segmentation"):
|
| 774 |
with gr.Row():
|
| 775 |
with gr.Column():
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
label="✍️ Text Prompt",
|
| 779 |
-
placeholder="e.g., person running, red car, dog...",
|
| 780 |
-
lines=2
|
| 781 |
-
)
|
| 782 |
-
|
| 783 |
-
with gr.Accordion("⚙️ Settings", open=True):
|
| 784 |
-
vid_frames = gr.Slider(
|
| 785 |
-
10, 500, 60, 10,
|
| 786 |
-
label="Max Frames",
|
| 787 |
-
info="CPU mode: Có thể xử lý nhiều frames hơn, nhưng sẽ chậm hơn"
|
| 788 |
-
)
|
| 789 |
-
|
| 790 |
-
gr.Markdown("""
|
| 791 |
-
**💻 CPU Processing Mode:**
|
| 792 |
-
- ✅ **Không bị timeout** - xử lý bao nhiêu cũng được
|
| 793 |
-
- ⏱️ **Chậm hơn GPU** - khoảng 2-3 phút/frame
|
| 794 |
-
- 🔋 **Ổn định** - không crash, chạy nền background
|
| 795 |
-
|
| 796 |
-
**⏱️ Thời gian ước tính:**
|
| 797 |
-
- 30 frames: ~60-90 phút
|
| 798 |
-
- 60 frames: ~2-3 giờ
|
| 799 |
-
- 100 frames: ~3-5 giờ
|
| 800 |
-
|
| 801 |
-
**💡 Khuyến nghị:**
|
| 802 |
-
- Submit job và làm việc khác
|
| 803 |
-
- Nhấn "Check Status" để xem tiến độ
|
| 804 |
-
- Video sẽ được lưu khi hoàn thành
|
| 805 |
-
""")
|
| 806 |
-
|
| 807 |
-
vid_submit = gr.Button("🚀 Submit Job (Background)", variant="primary", size="lg")
|
| 808 |
-
vid_check = gr.Button("🔍 Check Status", variant="secondary")
|
| 809 |
-
vid_job_id = gr.Textbox(label="Job ID", visible=False)
|
| 810 |
-
|
| 811 |
-
with gr.Column():
|
| 812 |
-
gr.Markdown("### 📹 Video Outputs (3 versions)")
|
| 813 |
-
|
| 814 |
-
with gr.Tabs():
|
| 815 |
-
with gr.Tab("1️⃣ Overlay"):
|
| 816 |
-
vid_overlay = gr.Video(label="Original + Color Masks")
|
| 817 |
-
gr.Markdown("*Video gốc với màu mask phủ lên*")
|
| 818 |
-
|
| 819 |
-
with gr.Tab("2️⃣ Masks Only"):
|
| 820 |
-
vid_masks = gr.Video(label="White Masks on Black")
|
| 821 |
-
gr.Markdown("*Chỉ hiển thị mask màu trắng trên nền đen*")
|
| 822 |
-
|
| 823 |
-
with gr.Tab("3️⃣ Segmented"):
|
| 824 |
-
vid_segmented = gr.Video(label="Green Screen Background")
|
| 825 |
-
gr.Markdown("*Đối tượng với nền xanh lá (green screen)*")
|
| 826 |
-
|
| 827 |
-
vid_status = gr.Textbox(label="📊 Status", interactive=False)
|
| 828 |
-
|
| 829 |
-
def submit_vid(vid, prompt, frames):
|
| 830 |
-
if not vid or not prompt:
|
| 831 |
-
return None, None, None, "❌ Vui lòng cung cấp video và prompt", ""
|
| 832 |
-
jid = str(uuid.uuid4())
|
| 833 |
-
processing_queue.put({
|
| 834 |
-
'id': jid,
|
| 835 |
-
'type': 'video',
|
| 836 |
-
'video': vid,
|
| 837 |
-
'prompt': prompt,
|
| 838 |
-
'frame_limit': frames
|
| 839 |
-
})
|
| 840 |
-
return None, None, None, f"✅ Đã thêm vào hàng chờ (ID: {jid[:8]}). Đang xử lý...", jid
|
| 841 |
-
|
| 842 |
-
def check_vid(jid):
|
| 843 |
-
if not jid or jid not in processing_results:
|
| 844 |
-
return None, None, None, "❌ Không tìm thấy công việc"
|
| 845 |
-
|
| 846 |
-
r = processing_results[jid]
|
| 847 |
-
|
| 848 |
-
if r['status'] == 'processing':
|
| 849 |
-
return None, None, None, f"⏳ Đang xử lý... {r['progress']}%"
|
| 850 |
-
elif r['status'] == 'completed':
|
| 851 |
-
res = r['result']
|
| 852 |
-
status = f"""✅ Hoàn thành! Thời gian: {res.get('duration', 'N/A')}
|
| 853 |
-
|
| 854 |
-
📹 3 video đã được tạo:
|
| 855 |
-
• Overlay - Ảnh gốc với mask màu
|
| 856 |
-
• Masks Only - Chỉ mask (trắng/đen)
|
| 857 |
-
• Segmented - Đối tượng với green screen"""
|
| 858 |
-
return (
|
| 859 |
-
res.get('output_path'),
|
| 860 |
-
res.get('mask_video_path'),
|
| 861 |
-
res.get('segmented_video_path'),
|
| 862 |
-
status
|
| 863 |
-
)
|
| 864 |
-
else:
|
| 865 |
-
return None, None, None, f"❌ Lỗi: {r.get('error', 'Unknown')}"
|
| 866 |
-
|
| 867 |
-
vid_submit.click(
|
| 868 |
-
fn=submit_vid,
|
| 869 |
-
inputs=[vid_input, vid_prompt, vid_frames],
|
| 870 |
-
outputs=[vid_overlay, vid_masks, vid_segmented, vid_status, vid_job_id]
|
| 871 |
-
)
|
| 872 |
-
|
| 873 |
-
vid_check.click(
|
| 874 |
-
fn=check_vid,
|
| 875 |
-
inputs=[vid_job_id],
|
| 876 |
-
outputs=[vid_overlay, vid_masks, vid_segmented, vid_status]
|
| 877 |
-
)
|
| 878 |
-
|
| 879 |
-
# ===== CLICK TAB =====
|
| 880 |
-
with gr.Tab("👆 Click Segmentation"):
|
| 881 |
-
with gr.Row():
|
| 882 |
-
with gr.Column(scale=1):
|
| 883 |
-
click_input = gr.Image(
|
| 884 |
-
type="pil",
|
| 885 |
-
label="📤 Upload Image & Click Objects",
|
| 886 |
-
interactive=True,
|
| 887 |
-
height=450
|
| 888 |
-
)
|
| 889 |
-
gr.Markdown("""
|
| 890 |
-
**📝 Hướng dẫn:**
|
| 891 |
-
1. Upload ảnh
|
| 892 |
-
2. Click vào đối tượng bạn muốn phân đoạn
|
| 893 |
-
3. Kết quả hiển thị ngay lập tức
|
| 894 |
-
4. Click "Clear" để reset và bắt đầu lại
|
| 895 |
-
""")
|
| 896 |
|
| 897 |
-
|
|
|
|
|
|
|
| 898 |
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
with gr.Column(scale=1):
|
| 903 |
-
click_output = gr.Image(
|
| 904 |
-
type="pil",
|
| 905 |
-
label="🎨 Result Preview",
|
| 906 |
-
height=450,
|
| 907 |
-
interactive=False
|
| 908 |
-
)
|
| 909 |
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
if pts is None:
|
| 919 |
-
pts = []
|
| 920 |
-
if lbl is None:
|
| 921 |
-
lbl = []
|
| 922 |
-
|
| 923 |
-
pts.append([evt.index[0], evt.index[1]])
|
| 924 |
-
lbl.append(1)
|
| 925 |
-
|
| 926 |
-
jid = str(uuid.uuid4())
|
| 927 |
-
try:
|
| 928 |
-
res = process_click_job({
|
| 929 |
-
'id': jid,
|
| 930 |
-
'type': 'click',
|
| 931 |
-
'image': img,
|
| 932 |
-
'points': pts,
|
| 933 |
-
'labels': lbl
|
| 934 |
-
})
|
| 935 |
-
return res['image'], pts, lbl
|
| 936 |
-
except Exception as e:
|
| 937 |
-
print(f"Click error: {e}")
|
| 938 |
-
return img, pts, lbl
|
| 939 |
-
|
| 940 |
-
click_input.select(
|
| 941 |
-
fn=on_click,
|
| 942 |
-
inputs=[click_input, click_pts, click_lbl],
|
| 943 |
-
outputs=[click_output, click_pts, click_lbl]
|
| 944 |
)
|
| 945 |
|
| 946 |
-
|
| 947 |
-
fn=
|
| 948 |
-
|
|
|
|
| 949 |
)
|
| 950 |
-
|
| 951 |
-
# ===== DOWNLOAD TAB =====
|
| 952 |
-
with gr.Tab("📥 Download Results"):
|
| 953 |
-
gr.Markdown("""
|
| 954 |
-
# 📦 Download Center
|
| 955 |
-
### Tải về kết quả đã xử lý dưới dạng ZIP
|
| 956 |
-
""")
|
| 957 |
|
|
|
|
|
|
|
| 958 |
with gr.Row():
|
| 959 |
with gr.Column(scale=1):
|
| 960 |
-
gr.
|
| 961 |
-
|
| 962 |
-
download_dropdown = gr.Dropdown(
|
| 963 |
-
label="Chọn công việc đã hoàn thành",
|
| 964 |
-
choices=get_downloadable_jobs(),
|
| 965 |
-
interactive=True,
|
| 966 |
-
scale=1
|
| 967 |
-
)
|
| 968 |
|
| 969 |
with gr.Row():
|
| 970 |
-
|
| 971 |
-
download_btn = gr.Button("📥 Download ZIP", variant="primary", size="lg", scale=2)
|
| 972 |
|
| 973 |
-
|
| 974 |
-
|
|
|
|
| 975 |
with gr.Column(scale=1):
|
| 976 |
-
gr.
|
| 977 |
-
download_file = gr.File(label="Your ZIP file will appear here")
|
| 978 |
-
|
| 979 |
-
gr.Markdown("""
|
| 980 |
-
**📦 Package Contents:**
|
| 981 |
-
|
| 982 |
-
**Image Jobs:**
|
| 983 |
-
- `overlay.jpg` - Ảnh với mask màu
|
| 984 |
-
- `objects/object_*.png` - Từng đối tượng riêng lẻ (PNG transparent)
|
| 985 |
-
- `metadata.json` - Thông tin chi tiết
|
| 986 |
-
|
| 987 |
-
**Video Jobs:**
|
| 988 |
-
- `overlay_video.mp4` - Video với mask màu
|
| 989 |
-
- `masks_only.mp4` - Chỉ mask trắng/đen
|
| 990 |
-
- `segmented_video.mp4` - Video với green screen
|
| 991 |
-
- `metadata.json` - Thông tin chi tiết
|
| 992 |
-
|
| 993 |
-
**Click Jobs:**
|
| 994 |
-
- `result.jpg` - Ảnh kết quả
|
| 995 |
-
- `metadata.json` - Thông tin chi tiết
|
| 996 |
-
""")
|
| 997 |
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
zip_path = create_download_package(job_id)
|
| 1003 |
-
if zip_path and os.path.exists(zip_path):
|
| 1004 |
-
size_mb = os.path.getsize(zip_path) / 1024 / 1024
|
| 1005 |
-
return zip_path, f"✅ Sẵn sàng tải về! Kích thước: {size_mb:.2f} MB"
|
| 1006 |
-
|
| 1007 |
-
return None, "❌ Không thể tạo package. Job có thể đã bị xóa."
|
| 1008 |
-
|
| 1009 |
-
download_refresh.click(
|
| 1010 |
-
fn=lambda: gr.Dropdown(choices=get_downloadable_jobs()),
|
| 1011 |
-
outputs=[download_dropdown]
|
| 1012 |
)
|
| 1013 |
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
-
|
| 1017 |
-
outputs=[download_file, download_status]
|
| 1018 |
)
|
| 1019 |
|
| 1020 |
-
# ===== HISTORY
|
| 1021 |
-
with gr.Tab("
|
| 1022 |
-
with gr.Row():
|
| 1023 |
-
with gr.Column(scale=1):
|
| 1024 |
-
gr.Markdown("### 📈 Statistics Dashboard")
|
| 1025 |
-
|
| 1026 |
-
def update_stats():
|
| 1027 |
-
stats = get_history_stats()
|
| 1028 |
-
return (
|
| 1029 |
-
f"**{stats['total']}**\n\nTổng số jobs",
|
| 1030 |
-
f"**{stats['completed']}**\n\nHoàn thành",
|
| 1031 |
-
f"**{stats['errors']}**\n\nLỗi",
|
| 1032 |
-
f"**{stats['success_rate']}**\n\nTỷ lệ thành công"
|
| 1033 |
-
)
|
| 1034 |
-
|
| 1035 |
-
with gr.Row():
|
| 1036 |
-
stat_total = gr.Markdown("**0**\n\nTổng số jobs", elem_classes=["stat-card"])
|
| 1037 |
-
stat_completed = gr.Markdown("**0**\n\nHoàn thành", elem_classes=["stat-card"])
|
| 1038 |
-
|
| 1039 |
-
with gr.Row():
|
| 1040 |
-
stat_errors = gr.Markdown("**0**\n\nLỗi", elem_classes=["stat-card"])
|
| 1041 |
-
stat_success = gr.Markdown("**0%**\n\nTỷ lệ thành công", elem_classes=["stat-card"])
|
| 1042 |
-
|
| 1043 |
-
gr.Markdown("### 🎯 Quick Actions")
|
| 1044 |
-
|
| 1045 |
-
with gr.Row():
|
| 1046 |
-
btn_refresh = gr.Button("🔄 Refresh All", variant="primary")
|
| 1047 |
-
btn_export = gr.Button("📥 Export JSON", variant="secondary")
|
| 1048 |
-
|
| 1049 |
-
btn_clear_all = gr.Button("🗑️ Clear All History", variant="stop")
|
| 1050 |
-
|
| 1051 |
-
export_file = gr.File(label="Exported File", visible=False)
|
| 1052 |
-
clear_status = gr.Textbox(label="Status", interactive=False)
|
| 1053 |
-
|
| 1054 |
with gr.Row():
|
| 1055 |
with gr.Column():
|
| 1056 |
-
gr.
|
| 1057 |
-
|
| 1058 |
-
with gr.Row():
|
| 1059 |
-
search_input = gr.Textbox(
|
| 1060 |
-
placeholder="🔍 Tìm kiếm theo prompt...",
|
| 1061 |
-
label="Search",
|
| 1062 |
-
scale=2
|
| 1063 |
-
)
|
| 1064 |
-
filter_type = gr.Dropdown(
|
| 1065 |
-
choices=["all", "image", "video", "click"],
|
| 1066 |
-
value="all",
|
| 1067 |
-
label="Loại",
|
| 1068 |
-
scale=1
|
| 1069 |
-
)
|
| 1070 |
-
filter_status = gr.Dropdown(
|
| 1071 |
-
choices=["all", "completed", "error"],
|
| 1072 |
-
value="all",
|
| 1073 |
-
label="Trạng thái",
|
| 1074 |
-
scale=1
|
| 1075 |
-
)
|
| 1076 |
|
| 1077 |
-
|
| 1078 |
-
|
| 1079 |
-
|
| 1080 |
-
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
-
columns=4,
|
| 1086 |
-
height=400,
|
| 1087 |
-
object_fit="contain"
|
| 1088 |
-
)
|
| 1089 |
-
|
| 1090 |
-
def refresh_all():
|
| 1091 |
-
return (
|
| 1092 |
-
*update_stats(),
|
| 1093 |
-
format_history_table(),
|
| 1094 |
-
get_history_gallery()
|
| 1095 |
-
)
|
| 1096 |
-
|
| 1097 |
-
btn_refresh.click(
|
| 1098 |
-
fn=refresh_all,
|
| 1099 |
-
outputs=[stat_total, stat_completed, stat_errors, stat_success, history_table, history_gallery]
|
| 1100 |
-
)
|
| 1101 |
-
|
| 1102 |
-
btn_export.click(
|
| 1103 |
-
fn=export_history_json,
|
| 1104 |
-
outputs=[export_file]
|
| 1105 |
-
)
|
| 1106 |
-
|
| 1107 |
-
btn_clear_all.click(
|
| 1108 |
-
fn=clear_all_history,
|
| 1109 |
-
outputs=[clear_status]
|
| 1110 |
-
).then(
|
| 1111 |
-
fn=refresh_all,
|
| 1112 |
-
outputs=[stat_total, stat_completed, stat_errors, stat_success, history_table, history_gallery]
|
| 1113 |
-
)
|
| 1114 |
-
|
| 1115 |
-
def filter_and_display(keyword, ftype, fstatus):
|
| 1116 |
-
filtered = search_history(keyword, ftype, fstatus)
|
| 1117 |
-
if not filtered:
|
| 1118 |
-
return "<p style='text-align:center; color:#666; padding:40px;'>🔍 Không tìm thấy kết quả phù hợp</p>"
|
| 1119 |
-
return format_history_table()
|
| 1120 |
-
|
| 1121 |
-
search_input.change(
|
| 1122 |
-
fn=filter_and_display,
|
| 1123 |
-
inputs=[search_input, filter_type, filter_status],
|
| 1124 |
-
outputs=[history_table]
|
| 1125 |
-
)
|
| 1126 |
-
|
| 1127 |
-
filter_type.change(
|
| 1128 |
-
fn=filter_and_display,
|
| 1129 |
-
inputs=[search_input, filter_type, filter_status],
|
| 1130 |
-
outputs=[history_table]
|
| 1131 |
-
)
|
| 1132 |
|
| 1133 |
-
|
| 1134 |
-
fn=
|
| 1135 |
-
|
| 1136 |
-
outputs=[history_table]
|
| 1137 |
)
|
| 1138 |
-
|
| 1139 |
-
|
| 1140 |
-
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1144 |
|
| 1145 |
if __name__ == "__main__":
|
| 1146 |
-
print("🚀 Starting SAM3 Application...")
|
| 1147 |
-
print(f"📁 Output directory: {OUTPUTS_DIR}")
|
| 1148 |
-
print(f"📥 Downloads directory: {DOWNLOADS_DIR}")
|
| 1149 |
-
print(f"📊 History file: {HISTORY_FILE}")
|
| 1150 |
-
|
| 1151 |
demo.launch(
|
| 1152 |
-
server_name="0.0.0.0",
|
| 1153 |
-
server_port=7860,
|
| 1154 |
-
max_threads=10,
|
| 1155 |
-
show_error=True,
|
| 1156 |
-
share=False,
|
| 1157 |
css=custom_css,
|
| 1158 |
-
theme=app_theme
|
|
|
|
|
|
|
|
|
|
| 1159 |
)
|
|
|
|
| 1 |
import os
|
| 2 |
import cv2
|
| 3 |
+
import tempfile
|
| 4 |
import spaces
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
import matplotlib
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
from PIL import Image, ImageDraw
|
| 11 |
from typing import Iterable
|
| 12 |
from gradio.themes import Soft
|
|
|
|
| 21 |
import threading
|
| 22 |
import queue
|
| 23 |
import uuid
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# ============ THEME SETUP ============
|
| 26 |
colors.steel_blue = colors.Color(
|
| 27 |
name="steel_blue",
|
| 28 |
+
c50="#EBF3F8",
|
| 29 |
+
c100="#D3E5F0",
|
| 30 |
+
c200="#A8CCE1",
|
| 31 |
+
c300="#7DB3D2",
|
| 32 |
+
c400="#529AC3",
|
| 33 |
+
c500="#4682B4",
|
| 34 |
+
c600="#3E72A0",
|
| 35 |
+
c700="#36638C",
|
| 36 |
+
c800="#2E5378",
|
| 37 |
+
c900="#264364",
|
| 38 |
+
c950="#1E3450",
|
| 39 |
)
|
| 40 |
|
| 41 |
class CustomBlueTheme(Soft):
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
*,
|
| 45 |
+
primary_hue: colors.Color | str = colors.gray,
|
| 46 |
+
secondary_hue: colors.Color | str = colors.steel_blue,
|
| 47 |
+
neutral_hue: colors.Color | str = colors.slate,
|
| 48 |
+
text_size: sizes.Size | str = sizes.text_lg,
|
| 49 |
+
font: fonts.Font | str | Iterable[fonts.Font | str] = (
|
| 50 |
+
fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
|
| 51 |
+
),
|
| 52 |
+
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
|
| 53 |
+
fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
|
| 54 |
+
),
|
| 55 |
+
):
|
| 56 |
+
super().__init__(
|
| 57 |
+
primary_hue=primary_hue,
|
| 58 |
+
secondary_hue=secondary_hue,
|
| 59 |
+
neutral_hue=neutral_hue,
|
| 60 |
+
text_size=text_size,
|
| 61 |
+
font=font,
|
| 62 |
+
font_mono=font_mono,
|
| 63 |
+
)
|
| 64 |
super().set(
|
| 65 |
background_fill_primary="*primary_50",
|
| 66 |
background_fill_primary_dark="*primary_900",
|
|
|
|
| 86 |
app_theme = CustomBlueTheme()
|
| 87 |
|
| 88 |
# ============ GLOBAL SETUP ============
|
| 89 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 90 |
+
print(f"🖥️ Using compute device: {device}")
|
| 91 |
|
| 92 |
+
# History storage
|
| 93 |
HISTORY_DIR = "processing_history"
|
| 94 |
+
os.makedirs(HISTORY_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
| 95 |
HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
|
| 96 |
|
| 97 |
+
# Background processing queue
|
| 98 |
processing_queue = queue.Queue()
|
| 99 |
processing_results = {}
|
| 100 |
|
| 101 |
# Load models
|
| 102 |
+
print("⏳ Loading SAM3 Models permanently into memory...")
|
| 103 |
try:
|
| 104 |
+
print(" ... Loading Image Text Model")
|
| 105 |
IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3").to(device)
|
| 106 |
IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
|
| 107 |
+
|
| 108 |
+
print(" ... Loading Image Tracker Model")
|
| 109 |
TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3").to(device)
|
| 110 |
TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
|
| 111 |
+
|
| 112 |
+
print(" ... Loading Video Model")
|
| 113 |
+
VID_MODEL = Sam3VideoModel.from_pretrained("DiffusionWave/sam3").to(device, dtype=torch.bfloat16)
|
| 114 |
VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("DiffusionWave/sam3")
|
| 115 |
|
| 116 |
+
print("✅ All Models loaded successfully!")
|
| 117 |
except Exception as e:
|
| 118 |
+
print(f"❌ CRITICAL ERROR LOADING MODELS: {e}")
|
| 119 |
IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = VID_MODEL = VID_PROCESSOR = None
|
| 120 |
|
| 121 |
# ============ HISTORY MANAGEMENT ============
|
| 122 |
def load_history():
|
| 123 |
+
"""Load processing history from JSON file"""
|
| 124 |
if os.path.exists(HISTORY_FILE):
|
| 125 |
try:
|
| 126 |
+
with open(HISTORY_FILE, 'r') as f:
|
| 127 |
return json.load(f)
|
| 128 |
except:
|
| 129 |
return []
|
| 130 |
return []
|
| 131 |
|
| 132 |
+
def save_history(history_item):
|
| 133 |
+
"""Save a new history item"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
history = load_history()
|
| 135 |
+
history.insert(0, history_item) # Add to beginning
|
| 136 |
+
history = history[:100] # Keep last 100 items
|
| 137 |
+
with open(HISTORY_FILE, 'w') as f:
|
| 138 |
+
json.dump(history, f, indent=2)
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
def get_history_display():
|
| 141 |
+
"""Format history for display"""
|
| 142 |
history = load_history()
|
| 143 |
if not history:
|
| 144 |
+
return "Chưa có lịch sử xử lý nào"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
display_text = ""
|
| 147 |
+
for i, item in enumerate(history[:50], 1):
|
| 148 |
+
status_emoji = "✅" if item['status'] == 'completed' else "❌"
|
| 149 |
+
display_text += f"{status_emoji} **{item['type'].upper()}** - {item['timestamp']}\n"
|
| 150 |
+
display_text += f" Prompt: {item['prompt']}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
if item.get('output_path'):
|
| 152 |
+
display_text += f" File: `{os.path.basename(item['output_path'])}`\n"
|
| 153 |
+
display_text += "\n"
|
| 154 |
+
return display_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
+
# ============ UTILITY FUNCTIONS ============
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
def apply_mask_overlay(base_image, mask_data, opacity=0.5):
|
| 158 |
+
"""Draws segmentation masks on top of an image."""
|
| 159 |
if isinstance(base_image, np.ndarray):
|
| 160 |
base_image = Image.fromarray(base_image)
|
| 161 |
base_image = base_image.convert("RGBA")
|
|
|
|
| 167 |
mask_data = mask_data.cpu().numpy()
|
| 168 |
mask_data = mask_data.astype(np.uint8)
|
| 169 |
|
| 170 |
+
if mask_data.ndim == 4: mask_data = mask_data[0]
|
| 171 |
+
if mask_data.ndim == 3 and mask_data.shape[0] == 1: mask_data = mask_data[0]
|
|
|
|
|
|
|
| 172 |
|
| 173 |
num_masks = mask_data.shape[0] if mask_data.ndim == 3 else 1
|
| 174 |
if mask_data.ndim == 2:
|
|
|
|
| 177 |
|
| 178 |
try:
|
| 179 |
color_map = matplotlib.colormaps["rainbow"].resampled(max(num_masks, 1))
|
| 180 |
+
except AttributeError:
|
| 181 |
import matplotlib.cm as cm
|
| 182 |
color_map = cm.get_cmap("rainbow").resampled(max(num_masks, 1))
|
| 183 |
|
| 184 |
rgb_colors = [tuple(int(c * 255) for c in color_map(i)[:3]) for i in range(num_masks)]
|
| 185 |
composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
|
| 186 |
|
| 187 |
+
for i, single_mask in enumerate(mask_data):
|
| 188 |
+
mask_bitmap = Image.fromarray((single_mask * 255).astype(np.uint8))
|
| 189 |
+
if mask_bitmap.size != base_image.size:
|
| 190 |
+
mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)
|
| 191 |
|
| 192 |
+
fill_color = rgb_colors[i]
|
| 193 |
+
color_fill = Image.new("RGBA", base_image.size, fill_color + (0,))
|
| 194 |
+
mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
|
| 195 |
color_fill.putalpha(mask_alpha)
|
| 196 |
composite_layer = Image.alpha_composite(composite_layer, color_fill)
|
| 197 |
|
| 198 |
return Image.alpha_composite(base_image, composite_layer).convert("RGB")
|
| 199 |
|
| 200 |
def draw_points_on_image(image, points):
|
| 201 |
+
"""Draws red dots on the image to indicate click locations."""
|
| 202 |
if isinstance(image, np.ndarray):
|
| 203 |
image = Image.fromarray(image)
|
| 204 |
+
|
| 205 |
draw_img = image.copy()
|
| 206 |
draw = ImageDraw.Draw(draw_img)
|
| 207 |
+
|
| 208 |
+
for pt in points:
|
| 209 |
+
x, y = pt
|
| 210 |
r = 8
|
| 211 |
draw.ellipse((x-r, y-r, x+r, y+r), fill="red", outline="white", width=4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
+
return draw_img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
+
# ============ BACKGROUND PROCESSING WORKER ============
|
| 216 |
def background_worker():
|
| 217 |
+
"""Background thread that processes jobs from queue"""
|
| 218 |
while True:
|
| 219 |
try:
|
| 220 |
job = processing_queue.get()
|
|
|
|
| 224 |
job_id = job['id']
|
| 225 |
job_type = job['type']
|
| 226 |
|
|
|
|
|
|
|
| 227 |
processing_results[job_id] = {'status': 'processing', 'progress': 0}
|
| 228 |
|
| 229 |
try:
|
|
|
|
| 240 |
'progress': 100
|
| 241 |
}
|
| 242 |
|
| 243 |
+
# Save to history
|
|
|
|
| 244 |
save_history({
|
| 245 |
'id': job_id,
|
| 246 |
'type': job_type,
|
| 247 |
'prompt': job.get('prompt', 'N/A'),
|
| 248 |
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
| 249 |
'status': 'completed',
|
| 250 |
+
'output_path': result.get('output_path')
|
| 251 |
})
|
| 252 |
|
| 253 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
processing_results[job_id] = {
|
| 255 |
'status': 'error',
|
| 256 |
'error': str(e),
|
|
|
|
| 265 |
'error': str(e)
|
| 266 |
})
|
| 267 |
except Exception as e:
|
| 268 |
+
print(f"Worker error: {e}")
|
| 269 |
+
|
| 270 |
+
# Start background worker
|
| 271 |
+
worker_thread = threading.Thread(target=background_worker, daemon=True)
|
| 272 |
+
worker_thread.start()
|
| 273 |
+
|
| 274 |
+
# ============ JOB PROCESSORS ============
|
| 275 |
+
@spaces.GPU
|
| 276 |
+
def process_image_job(job):
|
| 277 |
+
"""Process image segmentation job"""
|
| 278 |
+
source_img = job['image']
|
| 279 |
+
text_query = job['prompt']
|
| 280 |
+
conf_thresh = job.get('conf_thresh', 0.5)
|
| 281 |
+
|
| 282 |
+
if isinstance(source_img, str):
|
| 283 |
+
source_img = Image.open(source_img)
|
| 284 |
+
|
| 285 |
+
pil_image = source_img.convert("RGB")
|
| 286 |
+
model_inputs = IMG_PROCESSOR(images=pil_image, text=text_query, return_tensors="pt").to(device)
|
| 287 |
+
|
| 288 |
+
with torch.no_grad():
|
| 289 |
+
inference_output = IMG_MODEL(**model_inputs)
|
| 290 |
+
|
| 291 |
+
processed_results = IMG_PROCESSOR.post_process_instance_segmentation(
|
| 292 |
+
inference_output,
|
| 293 |
+
threshold=conf_thresh,
|
| 294 |
+
mask_threshold=0.5,
|
| 295 |
+
target_sizes=model_inputs.get("original_sizes").tolist()
|
| 296 |
+
)[0]
|
| 297 |
+
|
| 298 |
+
annotation_list = []
|
| 299 |
+
raw_masks = processed_results['masks'].cpu().numpy()
|
| 300 |
+
raw_scores = processed_results['scores'].cpu().numpy()
|
| 301 |
+
|
| 302 |
+
for idx, mask_array in enumerate(raw_masks):
|
| 303 |
+
label_str = f"{text_query} ({raw_scores[idx]:.2f})"
|
| 304 |
+
annotation_list.append((mask_array, label_str))
|
| 305 |
+
|
| 306 |
+
# Save output
|
| 307 |
+
output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.jpg")
|
| 308 |
+
result_img = apply_mask_overlay(pil_image, raw_masks)
|
| 309 |
+
result_img.save(output_path)
|
| 310 |
+
|
| 311 |
+
return {
|
| 312 |
+
'image': (pil_image, annotation_list),
|
| 313 |
+
'output_path': output_path
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
@spaces.GPU
|
| 317 |
+
def process_video_job(job):
|
| 318 |
+
"""Process video segmentation job"""
|
| 319 |
+
source_vid = job['video']
|
| 320 |
+
text_query = job['prompt']
|
| 321 |
+
frame_limit = job.get('frame_limit', 60)
|
| 322 |
+
|
| 323 |
+
video_cap = cv2.VideoCapture(source_vid)
|
| 324 |
+
vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
|
| 325 |
+
vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 326 |
+
vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 327 |
+
|
| 328 |
+
video_frames = []
|
| 329 |
+
counter = 0
|
| 330 |
+
while video_cap.isOpened():
|
| 331 |
+
ret, frame = video_cap.read()
|
| 332 |
+
if not ret or (frame_limit > 0 and counter >= frame_limit): break
|
| 333 |
+
video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 334 |
+
counter += 1
|
| 335 |
+
video_cap.release()
|
| 336 |
+
|
| 337 |
+
session = VID_PROCESSOR.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
|
| 338 |
+
session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)
|
| 339 |
+
|
| 340 |
+
output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.mp4")
|
| 341 |
+
video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
|
| 342 |
+
|
| 343 |
+
total_frames = len(video_frames)
|
| 344 |
+
for frame_idx, model_out in enumerate(VID_MODEL.propagate_in_video_iterator(inference_session=session, max_frame_num_to_track=total_frames)):
|
| 345 |
+
post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
|
| 346 |
+
f_idx = model_out.frame_idx
|
| 347 |
+
original_pil = Image.fromarray(video_frames[f_idx])
|
| 348 |
+
|
| 349 |
+
if 'masks' in post_processed:
|
| 350 |
+
detected_masks = post_processed['masks']
|
| 351 |
+
if detected_masks.ndim == 4: detected_masks = detected_masks.squeeze(1)
|
| 352 |
+
final_frame = apply_mask_overlay(original_pil, detected_masks)
|
| 353 |
+
else:
|
| 354 |
+
final_frame = original_pil
|
| 355 |
+
|
| 356 |
+
video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))
|
| 357 |
+
|
| 358 |
+
# Update progress
|
| 359 |
+
progress = int((frame_idx + 1) / total_frames * 100)
|
| 360 |
+
processing_results[job['id']]['progress'] = progress
|
| 361 |
+
|
| 362 |
+
video_writer.release()
|
| 363 |
+
return {'output_path': output_path}
|
| 364 |
+
|
| 365 |
+
@spaces.GPU
|
| 366 |
+
def process_click_job(job):
|
| 367 |
+
"""Process click segmentation job"""
|
| 368 |
+
input_image = job['image']
|
| 369 |
+
points_state = job['points']
|
| 370 |
+
labels_state = job['labels']
|
| 371 |
+
|
| 372 |
+
if isinstance(input_image, str):
|
| 373 |
+
input_image = Image.open(input_image)
|
| 374 |
+
|
| 375 |
+
input_points = [[points_state]]
|
| 376 |
+
input_labels = [[labels_state]]
|
| 377 |
+
|
| 378 |
+
inputs = TRK_PROCESSOR(images=input_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
|
| 379 |
+
|
| 380 |
+
with torch.no_grad():
|
| 381 |
+
outputs = TRK_MODEL(**inputs, multimask_output=False)
|
| 382 |
+
|
| 383 |
+
masks = TRK_PROCESSOR.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
|
| 384 |
+
final_img = apply_mask_overlay(input_image, masks[0])
|
| 385 |
+
final_img = draw_points_on_image(final_img, points_state)
|
| 386 |
+
|
| 387 |
+
output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.jpg")
|
| 388 |
+
final_img.save(output_path)
|
| 389 |
+
|
| 390 |
+
return {
|
| 391 |
+
'image': final_img,
|
| 392 |
+
'output_path': output_path
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
# ============ UI HANDLERS ============
|
| 396 |
+
def submit_image_job(source_img, text_query, conf_thresh):
|
| 397 |
+
"""Submit image segmentation job to background queue"""
|
| 398 |
+
if source_img is None or not text_query:
|
| 399 |
+
return None, "❌ Vui lòng cung cấp ảnh và prompt", ""
|
| 400 |
+
|
| 401 |
+
job_id = str(uuid.uuid4())
|
| 402 |
+
job = {
|
| 403 |
+
'id': job_id,
|
| 404 |
+
'type': 'image',
|
| 405 |
+
'image': source_img,
|
| 406 |
+
'prompt': text_query,
|
| 407 |
+
'conf_thresh': conf_thresh
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
processing_queue.put(job)
|
| 411 |
+
return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
|
| 412 |
|
| 413 |
+
def check_image_status(job_id):
|
| 414 |
+
"""Check status of image processing job"""
|
| 415 |
+
if not job_id or job_id not in processing_results:
|
| 416 |
+
return None, "Không tìm thấy công việc"
|
| 417 |
+
|
| 418 |
+
result = processing_results[job_id]
|
| 419 |
+
|
| 420 |
+
if result['status'] == 'processing':
|
| 421 |
+
return None, f"⏳ Đang xử lý... {result['progress']}%"
|
| 422 |
+
elif result['status'] == 'completed':
|
| 423 |
+
return result['result']['image'], "✅ Hoàn thành!"
|
| 424 |
+
else:
|
| 425 |
+
return None, f"❌ Lỗi: {result.get('error', 'Unknown')}"
|
| 426 |
+
|
| 427 |
+
def submit_video_job(source_vid, text_query, frame_limit, time_limit):
|
| 428 |
+
"""Submit video segmentation job to background queue"""
|
| 429 |
+
if not source_vid or not text_query:
|
| 430 |
+
return None, "❌ Vui lòng cung cấp video và prompt", ""
|
| 431 |
+
|
| 432 |
+
job_id = str(uuid.uuid4())
|
| 433 |
+
job = {
|
| 434 |
+
'id': job_id,
|
| 435 |
+
'type': 'video',
|
| 436 |
+
'video': source_vid,
|
| 437 |
+
'prompt': text_query,
|
| 438 |
+
'frame_limit': frame_limit,
|
| 439 |
+
'time_limit': time_limit
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
processing_queue.put(job)
|
| 443 |
+
return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
|
| 444 |
|
| 445 |
+
def check_video_status(job_id):
|
| 446 |
+
"""Check status of video processing job"""
|
| 447 |
+
if not job_id or job_id not in processing_results:
|
| 448 |
+
return None, "Không tìm thấy công việc"
|
| 449 |
+
|
| 450 |
+
result = processing_results[job_id]
|
| 451 |
+
|
| 452 |
+
if result['status'] == 'processing':
|
| 453 |
+
return None, f"⏳ Đang xử lý... {result['progress']}%"
|
| 454 |
+
elif result['status'] == 'completed':
|
| 455 |
+
return result['result']['output_path'], "✅ Hoàn thành!"
|
| 456 |
+
else:
|
| 457 |
+
return None, f"❌ Lỗi: {result.get('error', 'Unknown')}"
|
| 458 |
+
|
| 459 |
+
def image_click_handler(image, evt: gr.SelectData, points_state, labels_state):
|
| 460 |
+
"""Handle click events for interactive segmentation"""
|
| 461 |
+
x, y = evt.index
|
| 462 |
+
|
| 463 |
+
if points_state is None: points_state = []
|
| 464 |
+
if labels_state is None: labels_state = []
|
| 465 |
+
|
| 466 |
+
points_state.append([x, y])
|
| 467 |
+
labels_state.append(1)
|
| 468 |
+
|
| 469 |
+
# Process immediately (can be changed to background if needed)
|
| 470 |
+
job_id = str(uuid.uuid4())
|
| 471 |
+
job = {
|
| 472 |
+
'id': job_id,
|
| 473 |
+
'type': 'click',
|
| 474 |
+
'image': image,
|
| 475 |
+
'points': points_state,
|
| 476 |
+
'labels': labels_state
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
try:
|
| 480 |
+
result = process_click_job(job)
|
| 481 |
+
return result['image'], points_state, labels_state
|
| 482 |
+
except Exception as e:
|
| 483 |
+
print(f"Click error: {e}")
|
| 484 |
+
return image, points_state, labels_state
|
| 485 |
+
|
| 486 |
+
# ============ GRADIO INTERFACE ============
# Layout-only CSS: centers the page, shrinks the title, scrolls history.
custom_css="""
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.1em !important; }
.history-box { max-height: 600px; overflow-y: auto; }
"""

# Top-level UI definition. Callback functions (submit_image_job,
# check_image_status, submit_video_job, get_history_display,
# process_click_job) are defined earlier in this file.
with gr.Blocks(css=custom_css, theme=app_theme) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# **SAM3: Segment Anything Model 3** 🚀", elem_id="main-title")
        gr.Markdown("Xử lý ảnh/video với **background processing** - không cần chờ đợi!")

        with gr.Tabs():
            # ===== IMAGE SEGMENTATION TAB =====
            # Submit enqueues a background job; Check polls its status by id.
            with gr.Tab("📷 Image Segmentation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        image_input = gr.Image(label="Upload Image", type="pil", height=350)
                        txt_prompt_img = gr.Textbox(label="Text Prompt", placeholder="e.g., cat, face, car wheel")
                        with gr.Accordion("Advanced Settings", open=False):
                            conf_slider = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="Confidence Threshold")

                        btn_submit_img = gr.Button("🚀 Submit Job (Background)", variant="primary")
                        btn_check_img = gr.Button("🔍 Check Status", variant="secondary")
                        # Hidden textbox carries the job id between the
                        # submit handler and the status-check handler.
                        job_id_img = gr.Textbox(label="Job ID", visible=False)

                    # NOTE(review): gr.Column scale is documented as an int;
                    # 1.5 may be coerced or warned about — confirm on the
                    # installed Gradio version.
                    with gr.Column(scale=1.5):
                        image_result = gr.AnnotatedImage(label="Segmented Result", height=410)
                        status_img = gr.Textbox(label="Status", interactive=False)

                btn_submit_img.click(
                    fn=submit_image_job,
                    inputs=[image_input, txt_prompt_img, conf_slider],
                    outputs=[image_result, status_img, job_id_img]
                )

                btn_check_img.click(
                    fn=check_image_status,
                    inputs=[job_id_img],
                    outputs=[image_result, status_img]
                )

            # ===== VIDEO SEGMENTATION TAB =====
            # Same submit/poll pattern as images, plus frame/time caps.
            with gr.Tab("🎥 Video Segmentation"):
                with gr.Row():
                    with gr.Column():
                        video_input = gr.Video(label="Upload Video", format="mp4", height=320)
                        txt_prompt_vid = gr.Textbox(label="Text Prompt", placeholder="e.g., person running, red car")

                        with gr.Row():
                            frame_limiter = gr.Slider(10, 500, value=60, step=10, label="Max Frames")
                            time_limiter = gr.Radio([60, 120, 180], value=60, label="Timeout (seconds)")

                        btn_submit_vid = gr.Button("🚀 Submit Job (Background)", variant="primary")
                        btn_check_vid = gr.Button("🔍 Check Status", variant="secondary")
                        job_id_vid = gr.Textbox(label="Job ID", visible=False)

                    with gr.Column():
                        video_result = gr.Video(label="Processed Video")
                        status_vid = gr.Textbox(label="Status", interactive=False)

                btn_submit_vid.click(
                    fn=submit_video_job,
                    inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
                    outputs=[video_result, status_vid, job_id_vid]
                )

                btn_check_vid.click(
                    fn=check_video_status,
                    inputs=[job_id_vid],
                    outputs=[video_result, status_vid]
                )

            # ===== CLICK SEGMENTATION TAB =====
            # Interactive point-prompt segmentation; runs synchronously.
            with gr.Tab("👆 Click Segmentation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        img_click_input = gr.Image(type="pil", label="Upload Image", interactive=True, height=450)
                        gr.Markdown("**Hướng dẫn:** Click vào đối tượng bạn muốn phân đoạn")

                        with gr.Row():
                            img_click_clear = gr.Button("🔄 Clear Points & Reset", variant="primary")

                        # Accumulated click coordinates and their labels.
                        st_click_points = gr.State([])
                        st_click_labels = gr.State([])

                    with gr.Column(scale=1):
                        img_click_output = gr.Image(type="pil", label="Result Preview", height=450, interactive=False)

                img_click_input.select(
                    image_click_handler,
                    inputs=[img_click_input, st_click_points, st_click_labels],
                    outputs=[img_click_output, st_click_points, st_click_labels]
                )

                # Reset preview and both point/label states in one shot.
                img_click_clear.click(
                    lambda: (None, [], []),
                    outputs=[img_click_output, st_click_points, st_click_labels]
                )

            # ===== HISTORY TAB =====
            with gr.Tab("📜 Lịch Sử Xử Lý"):
                with gr.Row():
                    with gr.Column():
                        btn_refresh_history = gr.Button("🔄 Refresh History", variant="primary")
                        # get_history_display() is evaluated once at build
                        # time; the refresh button re-runs it on demand.
                        history_display = gr.Markdown(value=get_history_display(), elem_classes="history-box")

                with gr.Accordion("Hướng dẫn", open=False):
                    gr.Markdown("""
                    ### Lịch sử lưu:
                    - ✅ **Hoàn thành**: File đã được xử lý thành công
                    - ❌ **Lỗi**: Xử lý thất bại
                    - Tất cả file output được lưu trong thư mục `processing_history/`
                    - Hệ thống giữ lại 100 lịch sử gần nhất
                    """)

                btn_refresh_history.click(
                    fn=get_history_display,
                    outputs=[history_display]
                )

            # ===== BATCH PROCESSING TAB =====
            # Placeholder tab: no components wired yet.
            with gr.Tab("⚙️ Batch Processing"):
                gr.Markdown("### Xử lý hàng loạt (Coming Soon)")
                gr.Markdown("""
                Tính năng này sẽ cho phép bạn:
                - Upload nhiều ảnh/video cùng lúc
                - Tự động xử lý tuần tự
                - Download tất cả kết quả dưới dạng ZIP
                """)
| 617 |
if __name__ == "__main__":
    # FIX: `css` and `theme` are constructor arguments of gr.Blocks (already
    # supplied when `demo` was created above); Blocks.launch() does not
    # accept them and raises TypeError on unexpected keyword arguments, so
    # they are dropped here.
    demo.launch(
        ssr_mode=False,    # disable server-side rendering
        mcp_server=True,   # also expose the app's endpoints as an MCP server
        show_error=True    # surface Python exceptions in the web UI
    )