|
|
import os |
|
|
import cv2 |
|
|
import tempfile |
|
|
import spaces |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
import matplotlib |
|
|
import matplotlib.pyplot as plt |
|
|
from PIL import Image, ImageDraw |
|
|
from typing import Iterable |
|
|
from gradio.themes import Soft |
|
|
from gradio.themes.utils import colors, fonts, sizes |
|
|
from transformers import ( |
|
|
Sam3Model, Sam3Processor, |
|
|
Sam3VideoModel, Sam3VideoProcessor, |
|
|
Sam3TrackerModel, Sam3TrackerProcessor |
|
|
) |
|
|
import json |
|
|
from datetime import datetime |
|
|
import threading |
|
|
import queue |
|
|
import uuid |
|
|
|
|
|
|
|
|
# Register a custom "steel_blue" palette on gradio's ``colors`` module so the
# theme class below can use it as a default hue.
_STEEL_BLUE_SHADES = {
    "c50": "#EBF3F8",
    "c100": "#D3E5F0",
    "c200": "#A8CCE1",
    "c300": "#7DB3D2",
    "c400": "#529AC3",
    "c500": "#4682B4",
    "c600": "#3E72A0",
    "c700": "#36638C",
    "c800": "#2E5378",
    "c900": "#264364",
    "c950": "#1E3450",
}
colors.steel_blue = colors.Color(name="steel_blue", **_STEEL_BLUE_SHADES)
|
|
|
|
|
class CustomBlueTheme(Soft):
    """Soft-based Gradio theme with a gray primary hue and steel-blue accents.

    All constructor arguments are keyword-only and are forwarded unchanged to
    ``Soft.__init__``; the visual overrides are then applied through ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Theme-variable overrides, grouped by the part of the UI they affect.
        overrides = {
            # Page and block backgrounds.
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary buttons.
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            # Sliders.
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            # Blocks, shadows and spacing.
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**overrides)
|
|
|
|
|
# Theme instance shared by the UI definition and the launch call.
app_theme = CustomBlueTheme()

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🖥️ Using compute device: {device}")

# On-disk location of job outputs and of the JSON history log.
HISTORY_DIR = "processing_history"
HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
os.makedirs(HISTORY_DIR, exist_ok=True)

# Job status/result table keyed by job id, and the queue feeding the worker.
processing_results = {}
processing_queue = queue.Queue()
|
|
|
|
|
|
|
|
# Load every model/processor pair once at import time. If any step fails, all
# six globals are reset to None and the error is only printed.
_SAM3_CHECKPOINT = "DiffusionWave/sam3"

print("⏳ Loading SAM3 Models permanently into memory...")
try:
    print(" ... Loading Image Text Model")
    IMG_MODEL = Sam3Model.from_pretrained(_SAM3_CHECKPOINT).to(device)
    IMG_PROCESSOR = Sam3Processor.from_pretrained(_SAM3_CHECKPOINT)

    print(" ... Loading Image Tracker Model")
    TRK_MODEL = Sam3TrackerModel.from_pretrained(_SAM3_CHECKPOINT).to(device)
    TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained(_SAM3_CHECKPOINT)

    print(" ... Loading Video Model")
    # The video model is the only one moved to the device in bfloat16.
    VID_MODEL = Sam3VideoModel.from_pretrained(_SAM3_CHECKPOINT).to(device, dtype=torch.bfloat16)
    VID_PROCESSOR = Sam3VideoProcessor.from_pretrained(_SAM3_CHECKPOINT)

    print("✅ All Models loaded successfully!")
except Exception as load_error:
    print(f"❌ CRITICAL ERROR LOADING MODELS: {load_error}")
    IMG_MODEL = None
    IMG_PROCESSOR = None
    TRK_MODEL = None
    TRK_PROCESSOR = None
    VID_MODEL = None
    VID_PROCESSOR = None
|
|
|
|
|
|
|
|
def load_history():
    """Load the processing history from ``HISTORY_FILE``.

    Returns:
        list[dict]: The stored history entries, newest first. An empty list is
        returned when the file is missing, unreadable, not valid JSON, or does
        not contain a JSON array. Entries that are not objects are dropped so
        callers can rely on dict access.
    """
    try:
        with open(HISTORY_FILE, 'r') as f:
            history = json.load(f)
    except FileNotFoundError:
        # No job has been recorded yet.
        return []
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"⚠️ Could not read history file {HISTORY_FILE}: {exc}")
        return []

    if not isinstance(history, list):
        return []
    return [item for item in history if isinstance(item, dict)]
|
|
|
|
|
def save_history(history_item, max_items=100):
    """Prepend ``history_item`` to the history file, keeping the newest entries.

    The new content is written to a sibling temporary file and then moved over
    ``HISTORY_FILE`` with ``os.replace``, so an interrupted write cannot leave
    a truncated history behind.

    Args:
        history_item: JSON-serialisable record describing one job.
        max_items: Maximum number of entries kept in the file (default 100).

    Raises:
        OSError: If the history file cannot be written.
        TypeError: If ``history_item`` is not JSON-serialisable.
    """
    history = load_history()
    if not isinstance(history, list):
        # Guard against a hand-edited file holding something other than a list.
        history = []
    history.insert(0, history_item)
    history = history[:max_items]

    tmp_path = f"{HISTORY_FILE}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(history, f, indent=2)
        os.replace(tmp_path, HISTORY_FILE)
    except Exception:
        # Do not leave a half-written temp file around; the old history is intact.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
|
|
|
|
def get_history_display():
    """Render the 50 most recent history entries as Markdown text.

    Returns:
        str: One paragraph per entry (status emoji, job type, timestamp,
        prompt and, when present, the output file name), or a placeholder
        message when there is no history. Entries with missing fields are
        rendered with fallback values instead of raising, because this
        function also runs while the UI is being built.
    """
    history = load_history()
    if not history:
        return "Chưa có lịch sử xử lý nào"

    chunks = []
    for item in history[:50]:
        if not isinstance(item, dict):
            # Skip anything that is not a job record.
            continue
        status_emoji = "✅" if item.get('status') == 'completed' else "❌"
        job_type = str(item.get('type', 'unknown')).upper()
        chunks.append(f"{status_emoji} **{job_type}** - {item.get('timestamp', 'N/A')}\n")
        chunks.append(f" Prompt: {item.get('prompt', 'N/A')}\n")
        if item.get('output_path'):
            chunks.append(f" File: `{os.path.basename(item['output_path'])}`\n")
        chunks.append("\n")
    return "".join(chunks)
|
|
|
|
|
|
|
|
def apply_mask_overlay(base_image, mask_data, opacity=0.5):
    """Draw segmentation masks on top of an image.

    Args:
        base_image: PIL image or numpy array to draw on.
        mask_data: ``None``, a torch tensor, or an array-like of masks with
            2 dimensions (one mask), 3 dimensions (one mask per leading index)
            or 4 dimensions (only the first leading slice is used). Non-zero
            values are treated as "inside the mask".
        opacity: Alpha of the coloured overlay, clamped to ``[0, 1]``.

    Returns:
        PIL.Image.Image: An RGB copy of ``base_image`` with each mask tinted in
        a distinct colour from the "rainbow" colormap.

    Raises:
        ValueError: If ``mask_data`` has an unsupported number of dimensions.
    """
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")

    if mask_data is None or len(mask_data) == 0:
        return base_image.convert("RGB")

    if isinstance(mask_data, torch.Tensor):
        # Cast on the torch side: numpy has no bfloat16, so calling .numpy()
        # directly on such a tensor would raise.
        mask_data = mask_data.detach().cpu().to(torch.uint8).numpy()
    else:
        mask_data = np.asarray(mask_data).astype(np.uint8)

    if mask_data.ndim == 4:
        mask_data = mask_data[0]
    if mask_data.ndim == 2:
        mask_data = mask_data[np.newaxis, ...]
    if mask_data.ndim != 3:
        raise ValueError(f"Unsupported mask shape: {mask_data.shape}")

    num_masks = mask_data.shape[0]
    n_colors = max(num_masks, 1)
    try:
        color_map = matplotlib.colormaps["rainbow"].resampled(n_colors)
    except AttributeError:
        # Older matplotlib has neither the registry nor Colormap.resampled;
        # there the LUT size is passed to get_cmap directly.
        import matplotlib.cm as cm
        color_map = cm.get_cmap("rainbow", n_colors)

    rgb_colors = [tuple(int(c * 255) for c in color_map(i)[:3]) for i in range(num_masks)]
    opacity = min(max(float(opacity), 0.0), 1.0)
    composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))

    for fill_color, single_mask in zip(rgb_colors, mask_data):
        # Binarise explicitly so values above 1 cannot wrap around in uint8.
        mask_bitmap = Image.fromarray(np.where(single_mask > 0, 255, 0).astype(np.uint8))
        if mask_bitmap.size != base_image.size:
            mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)

        color_fill = Image.new("RGBA", base_image.size, fill_color + (0,))
        mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
        color_fill.putalpha(mask_alpha)
        composite_layer = Image.alpha_composite(composite_layer, color_fill)

    return Image.alpha_composite(base_image, composite_layer).convert("RGB")
|
|
|
|
|
def draw_points_on_image(image, points):
    """Return a copy of ``image`` with a red, white-outlined dot at each point.

    ``image`` may be a PIL image or a numpy array; ``points`` is an iterable of
    ``(x, y)`` pairs. The input image itself is not modified.
    """
    source = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    canvas = source.copy()
    pen = ImageDraw.Draw(canvas)

    radius = 8
    for px, py in points:
        bounding_box = (px - radius, py - radius, px + radius, py + radius)
        pen.ellipse(bounding_box, fill="red", outline="white", width=4)

    return canvas
|
|
|
|
|
|
|
|
def _record_job_history(entry):
    """Persist one history entry; a failure here must not change job status."""
    try:
        save_history(entry)
    except Exception as e:
        print(f"Worker error: could not save history for job {entry.get('id')}: {e}")


def _run_job(job):
    """Execute a single queued job and publish its outcome in ``processing_results``."""
    job_id = job['id']
    job_type = job['type']
    processing_results[job_id] = {'status': 'processing', 'progress': 0}

    # Resolved at call time, so handlers defined later in the module are fine.
    handlers = {
        'image': process_image_job,
        'video': process_video_job,
        'click': process_click_job,
    }
    entry = {
        'id': job_id,
        'type': job_type,
        'prompt': job.get('prompt', 'N/A'),
        'timestamp': None,
        'status': None,
    }

    try:
        handler = handlers.get(job_type)
        if handler is None:
            raise ValueError(f"Unknown job type: {job_type!r}")
        result = handler(job)
        output_path = result.get('output_path')
    except Exception as e:
        processing_results[job_id] = {
            'status': 'error',
            'error': str(e),
            'progress': 0
        }
        entry['status'] = 'error'
        entry['error'] = str(e)
    else:
        processing_results[job_id] = {
            'status': 'completed',
            'result': result,
            'progress': 100
        }
        entry['status'] = 'completed'
        entry['output_path'] = output_path

    entry['timestamp'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Saved outside the try above so that a history write failure cannot turn
    # an already completed job into an 'error' result.
    _record_job_history(entry)


def background_worker():
    """Consume jobs from ``processing_queue`` until a ``None`` sentinel arrives.

    Each job is a dict with at least ``'id'`` and ``'type'`` (``'image'``,
    ``'video'`` or ``'click'``). The outcome is stored in
    ``processing_results[job_id]`` and appended to the history file. Errors in
    one job are reported and never stop the loop.
    """
    while True:
        job = processing_queue.get()
        try:
            if job is None:
                break
            _run_job(job)
        except Exception as e:
            print(f"Worker error: {e}")
        finally:
            # Keeps Queue.join() usable; runs for the sentinel as well.
            processing_queue.task_done()
|
|
|
|
|
|
|
|
# Start the single queue consumer as a daemon thread so it does not keep the
# interpreter alive at shutdown.
worker_thread = threading.Thread(daemon=True, target=background_worker)
worker_thread.start()
|
|
|
|
|
|
|
|
@spaces.GPU
def process_image_job(job):
    """Run text-prompted instance segmentation on one image.

    Args:
        job: Dict with ``'id'``, ``'image'`` (PIL image or file path),
            ``'prompt'`` and optionally ``'conf_thresh'`` (default 0.5).

    Returns:
        dict: ``'image'`` is ``(rgb_image, [(mask, label), ...])`` as consumed by
        the annotated-image component; ``'output_path'`` is the JPEG with the
        mask overlay saved under ``HISTORY_DIR``.

    Raises:
        RuntimeError: If the image model failed to load at startup.
    """
    if IMG_MODEL is None or IMG_PROCESSOR is None:
        raise RuntimeError("SAM3 image model is not loaded")

    source_img = job['image']
    text_query = job['prompt']
    conf_thresh = job.get('conf_thresh', 0.5)

    if isinstance(source_img, str):
        # convert() returns an independent image, so the file can be closed.
        with Image.open(source_img) as opened:
            pil_image = opened.convert("RGB")
    else:
        pil_image = source_img.convert("RGB")

    model_inputs = IMG_PROCESSOR(images=pil_image, text=text_query, return_tensors="pt").to(device)

    with torch.no_grad():
        inference_output = IMG_MODEL(**model_inputs)

    processed_results = IMG_PROCESSOR.post_process_instance_segmentation(
        inference_output,
        threshold=conf_thresh,
        mask_threshold=0.5,
        target_sizes=model_inputs.get("original_sizes").tolist()
    )[0]

    raw_masks = processed_results['masks'].cpu().numpy()
    raw_scores = processed_results['scores'].cpu().numpy()

    # One (mask, "prompt (score)") pair per detected instance.
    annotation_list = [
        (mask_array, f"{text_query} ({score:.2f})")
        for mask_array, score in zip(raw_masks, raw_scores)
    ]

    output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.jpg")
    result_img = apply_mask_overlay(pil_image, raw_masks)
    result_img.save(output_path)

    return {
        'image': (pil_image, annotation_list),
        'output_path': output_path
    }
|
|
|
|
|
@spaces.GPU
def process_video_job(job):
    """Run text-prompted segmentation over the first frames of a video.

    Args:
        job: Dict with ``'id'``, ``'video'`` (path), ``'prompt'`` and optionally
            ``'frame_limit'`` (default 60; values <= 0 mean "no limit").

    Returns:
        dict: ``{'output_path': <path of the rendered mp4 under HISTORY_DIR>}``.

    Raises:
        RuntimeError: If the video model is not loaded or the output file
            cannot be opened for writing.
        ValueError: If no frame could be read from the input video.
    """
    # NOTE(review): job['time_limit'] is collected by the submit handler but is
    # not enforced anywhere in this function.
    if VID_MODEL is None or VID_PROCESSOR is None:
        raise RuntimeError("SAM3 video model is not loaded")

    source_vid = job['video']
    text_query = job['prompt']
    frame_limit = job.get('frame_limit', 60)

    video_cap = cv2.VideoCapture(source_vid)
    try:
        vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
        vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        video_frames = []
        while video_cap.isOpened():
            if frame_limit > 0 and len(video_frames) >= frame_limit:
                break
            ret, frame = video_cap.read()
            if not ret:
                break
            video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        video_cap.release()

    if not video_frames:
        raise ValueError(f"No frames could be read from video: {source_vid}")
    if not (vid_fps > 0):
        # Missing/invalid metadata (0 or NaN): fall back to a sane frame rate.
        vid_fps = 30.0

    session = VID_PROCESSOR.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
    session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)

    output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.mp4")
    video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
    if not video_writer.isOpened():
        video_writer.release()
        raise RuntimeError(f"Could not open video writer for {output_path}")

    total_frames = len(video_frames)
    try:
        propagation = VID_MODEL.propagate_in_video_iterator(
            inference_session=session, max_frame_num_to_track=total_frames
        )
        for frame_idx, model_out in enumerate(propagation):
            post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
            f_idx = model_out.frame_idx
            original_pil = Image.fromarray(video_frames[f_idx])

            if 'masks' in post_processed:
                detected_masks = post_processed['masks']
                if detected_masks.ndim == 4:
                    detected_masks = detected_masks.squeeze(1)
                final_frame = apply_mask_overlay(original_pil, detected_masks)
            else:
                final_frame = original_pil

            video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))

            # Publish progress for status polling; tolerate direct calls that
            # never registered the job in processing_results.
            status_entry = processing_results.get(job['id'])
            if status_entry is not None:
                status_entry['progress'] = int((frame_idx + 1) / total_frames * 100)
    finally:
        # Always finalise the container, even when inference fails midway.
        video_writer.release()

    return {'output_path': output_path}
|
|
|
|
|
@spaces.GPU
def process_click_job(job):
    """Segment the object indicated by click points on a single image.

    Args:
        job: Dict with ``'id'``, ``'image'`` (PIL image, numpy array or file
            path), ``'points'`` (list of ``[x, y]``) and ``'labels'`` (one label
            per point).

    Returns:
        dict: ``'image'`` is the RGB image with mask overlay and click markers;
        ``'output_path'`` is the JPEG saved under ``HISTORY_DIR``.

    Raises:
        RuntimeError: If the tracker model failed to load at startup.
        ValueError: If no click point was supplied.
    """
    if TRK_MODEL is None or TRK_PROCESSOR is None:
        raise RuntimeError("SAM3 tracker model is not loaded")

    input_image = job['image']
    points_state = job['points']
    labels_state = job['labels']
    if not points_state:
        raise ValueError("At least one click point is required")

    if isinstance(input_image, str):
        # convert() returns an independent image, so the file can be closed.
        with Image.open(input_image) as opened:
            input_image = opened.convert("RGB")
    elif isinstance(input_image, Image.Image):
        # Same normalisation as the text-prompt image job (drops alpha/palette).
        input_image = input_image.convert("RGB")

    # The processor call takes points/labels wrapped in two extra list levels.
    input_points = [[points_state]]
    input_labels = [[labels_state]]

    inputs = TRK_PROCESSOR(images=input_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = TRK_MODEL(**inputs, multimask_output=False)

    masks = TRK_PROCESSOR.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
    final_img = apply_mask_overlay(input_image, masks[0])
    final_img = draw_points_on_image(final_img, points_state)

    output_path = os.path.join(HISTORY_DIR, f"{job['id']}_result.jpg")
    final_img.save(output_path)

    return {
        'image': final_img,
        'output_path': output_path
    }
|
|
|
|
|
|
|
|
def submit_image_job(source_img, text_query, conf_thresh):
    """Queue an image segmentation job for the background worker.

    Args:
        source_img: Image to segment (``None`` is rejected).
        text_query: Text prompt; empty or whitespace-only prompts are rejected.
        conf_thresh: Confidence threshold forwarded to the job.

    Returns:
        tuple: ``(None, status_message, job_id)``; ``job_id`` is ``""`` when the
        input was rejected.
    """
    if source_img is None or not text_query or not text_query.strip():
        return None, "❌ Vui lòng cung cấp ảnh và prompt", ""

    job_id = str(uuid.uuid4())
    job = {
        'id': job_id,
        'type': 'image',
        'image': source_img,
        'prompt': text_query,
        'conf_thresh': conf_thresh
    }

    # Register the job before enqueueing so that a status check made before
    # the worker picks it up reports "processing" instead of "not found".
    processing_results[job_id] = {'status': 'processing', 'progress': 0}
    processing_queue.put(job)
    return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
|
|
|
|
|
def check_image_status(job_id):
    """Report the state of an image job as ``(image_or_None, status_message)``.

    The image slot is filled only once the job has completed; while it is still
    running, failed, or unknown, the first element is ``None``.
    """
    if not (job_id and job_id in processing_results):
        return None, "Không tìm thấy công việc"

    entry = processing_results[job_id]
    state = entry['status']

    if state == 'completed':
        return entry['result']['image'], "✅ Hoàn thành!"
    if state == 'processing':
        return None, f"⏳ Đang xử lý... {entry['progress']}%"
    return None, f"❌ Lỗi: {entry.get('error', 'Unknown')}"
|
|
|
|
|
def submit_video_job(source_vid, text_query, frame_limit, time_limit):
    """Queue a video segmentation job for the background worker.

    Args:
        source_vid: Path of the video to process (empty/``None`` is rejected).
        text_query: Text prompt; empty or whitespace-only prompts are rejected.
        frame_limit: Maximum number of frames, stored on the job.
        time_limit: Timeout value from the UI, stored on the job.

    Returns:
        tuple: ``(None, status_message, job_id)``; ``job_id`` is ``""`` when the
        input was rejected.
    """
    if not source_vid or not text_query or not text_query.strip():
        return None, "❌ Vui lòng cung cấp video và prompt", ""

    job_id = str(uuid.uuid4())
    job = {
        'id': job_id,
        'type': 'video',
        'video': source_vid,
        'prompt': text_query,
        'frame_limit': frame_limit,
        'time_limit': time_limit
    }

    # Register the job before enqueueing so that a status check made before
    # the worker picks it up reports "processing" instead of "not found".
    processing_results[job_id] = {'status': 'processing', 'progress': 0}
    processing_queue.put(job)
    return None, f"✅ Đã thêm vào hàng chờ (ID: {job_id[:8]}). Đang xử lý...", job_id
|
|
|
|
|
def check_video_status(job_id):
    """Report the state of a video job as ``(output_path_or_None, status_message)``.

    The output path is returned only once the job has completed; while it is
    still running, failed, or unknown, the first element is ``None``.
    """
    if not (job_id and job_id in processing_results):
        return None, "Không tìm thấy công việc"

    entry = processing_results[job_id]
    state = entry['status']

    if state == 'completed':
        return entry['result']['output_path'], "✅ Hoàn thành!"
    if state == 'processing':
        return None, f"⏳ Đang xử lý... {entry['progress']}%"
    return None, f"❌ Lỗi: {entry.get('error', 'Unknown')}"
|
|
|
|
|
def image_click_handler(image, evt: gr.SelectData, points_state, labels_state):
    """Handle a click on the input image for interactive segmentation.

    The clicked coordinate is added as a point with label ``1`` and the
    segmentation is run synchronously through ``process_click_job`` (it does
    not go through the background queue).

    Args:
        image: The image currently shown in the input component.
        evt: Gradio select event; ``evt.index`` holds the clicked ``(x, y)``.
        points_state: Points accumulated so far (``None`` means none).
        labels_state: Labels accumulated so far (``None`` means none).

    Returns:
        tuple: ``(preview_image, points, labels)``. On failure the unmodified
        input image is returned as the preview.
    """
    # Work on copies so the lists handed in by the caller are never mutated.
    points_state = list(points_state or [])
    labels_state = list(labels_state or [])

    if image is None:
        # Nothing to segment yet; leave the accumulated state untouched.
        return None, points_state, labels_state

    x, y = evt.index
    points_state.append([x, y])
    labels_state.append(1)

    job = {
        'id': str(uuid.uuid4()),
        'type': 'click',
        'image': image,
        'points': points_state,
        'labels': labels_state
    }

    try:
        result = process_click_job(job)
        return result['image'], points_state, labels_state
    except Exception as e:
        print(f"Click error: {e}")
        return image, points_state, labels_state
|
|
|
|
|
|
|
|
# Extra CSS for the page: centres and caps the main column, enlarges the title,
# and makes the history panel scroll instead of growing without bound.
custom_css="""
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.1em !important; }
.history-box { max-height: 600px; overflow-y: auto; }
"""
|
|
|
|
|
# UI definition: one tab per workflow (text-prompted image, text-prompted
# video, click-based image), plus a history viewer and a placeholder tab.
with gr.Blocks(css=custom_css, theme=app_theme) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# **SAM3: Segment Anything Model 3** 🚀", elem_id="main-title")
        gr.Markdown("Xử lý ảnh/video với **background processing** - không cần chờ đợi!")

        with gr.Tabs():

            # ---- Text-prompted image segmentation (queued job + polling) ----
            with gr.Tab("📷 Image Segmentation"):
                with gr.Row():
                    # Integer scales 2:3 keep the former 1:1.5 width ratio.
                    with gr.Column(scale=2):
                        image_input = gr.Image(label="Upload Image", type="pil", height=350)
                        txt_prompt_img = gr.Textbox(label="Text Prompt", placeholder="e.g., cat, face, car wheel")
                        with gr.Accordion("Advanced Settings", open=False):
                            conf_slider = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="Confidence Threshold")

                        btn_submit_img = gr.Button("🚀 Submit Job (Background)", variant="primary")
                        btn_check_img = gr.Button("🔍 Check Status", variant="secondary")
                        # Hidden field carrying the job id from submit to status check.
                        job_id_img = gr.Textbox(label="Job ID", visible=False)

                    with gr.Column(scale=3):
                        image_result = gr.AnnotatedImage(label="Segmented Result", height=410)
                        status_img = gr.Textbox(label="Status", interactive=False)

                btn_submit_img.click(
                    fn=submit_image_job,
                    inputs=[image_input, txt_prompt_img, conf_slider],
                    outputs=[image_result, status_img, job_id_img]
                )

                btn_check_img.click(
                    fn=check_image_status,
                    inputs=[job_id_img],
                    outputs=[image_result, status_img]
                )

            # ---- Text-prompted video segmentation (queued job + polling) ----
            with gr.Tab("🎥 Video Segmentation"):
                with gr.Row():
                    with gr.Column():
                        video_input = gr.Video(label="Upload Video", format="mp4", height=320)
                        txt_prompt_vid = gr.Textbox(label="Text Prompt", placeholder="e.g., person running, red car")

                        with gr.Row():
                            frame_limiter = gr.Slider(10, 500, value=60, step=10, label="Max Frames")
                            time_limiter = gr.Radio([60, 120, 180], value=60, label="Timeout (seconds)")

                        btn_submit_vid = gr.Button("🚀 Submit Job (Background)", variant="primary")
                        btn_check_vid = gr.Button("🔍 Check Status", variant="secondary")
                        # Hidden field carrying the job id from submit to status check.
                        job_id_vid = gr.Textbox(label="Job ID", visible=False)

                    with gr.Column():
                        video_result = gr.Video(label="Processed Video")
                        status_vid = gr.Textbox(label="Status", interactive=False)

                btn_submit_vid.click(
                    fn=submit_video_job,
                    inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
                    outputs=[video_result, status_vid, job_id_vid]
                )

                btn_check_vid.click(
                    fn=check_video_status,
                    inputs=[job_id_vid],
                    outputs=[video_result, status_vid]
                )

            # ---- Click-based segmentation (runs synchronously per click) ----
            with gr.Tab("👆 Click Segmentation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        img_click_input = gr.Image(type="pil", label="Upload Image", interactive=True, height=450)
                        gr.Markdown("**Hướng dẫn:** Click vào đối tượng bạn muốn phân đoạn")

                        with gr.Row():
                            img_click_clear = gr.Button("🔄 Clear Points & Reset", variant="primary")

                        # Per-session accumulated click points and their labels.
                        st_click_points = gr.State([])
                        st_click_labels = gr.State([])

                    with gr.Column(scale=1):
                        img_click_output = gr.Image(type="pil", label="Result Preview", height=450, interactive=False)

                img_click_input.select(
                    image_click_handler,
                    inputs=[img_click_input, st_click_points, st_click_labels],
                    outputs=[img_click_output, st_click_points, st_click_labels]
                )

                img_click_clear.click(
                    lambda: (None, [], []),
                    outputs=[img_click_output, st_click_points, st_click_labels]
                )

                # Points belong to one image: drop them whenever the input image
                # is replaced or removed, so they are not applied to the next one.
                img_click_input.upload(
                    lambda: (None, [], []),
                    outputs=[img_click_output, st_click_points, st_click_labels]
                )
                img_click_input.clear(
                    lambda: (None, [], []),
                    outputs=[img_click_output, st_click_points, st_click_labels]
                )

            # ---- Processing history viewer ----
            with gr.Tab("📜 Lịch Sử Xử Lý"):
                with gr.Row():
                    with gr.Column():
                        btn_refresh_history = gr.Button("🔄 Refresh History", variant="primary")
                        # Initial content is computed once, when the UI is built.
                        history_display = gr.Markdown(value=get_history_display(), elem_classes="history-box")

                        with gr.Accordion("Hướng dẫn", open=False):
                            gr.Markdown("""
                            ### Lịch sử lưu:
                            - ✅ **Hoàn thành**: File đã được xử lý thành công
                            - ❌ **Lỗi**: Xử lý thất bại
                            - Tất cả file output được lưu trong thư mục `processing_history/`
                            - Hệ thống giữ lại 100 lịch sử gần nhất
                            """)

                btn_refresh_history.click(
                    fn=get_history_display,
                    outputs=[history_display]
                )

            # ---- Placeholder for a future batch mode ----
            with gr.Tab("⚙️ Batch Processing"):
                gr.Markdown("### Xử lý hàng loạt (Coming Soon)")
                gr.Markdown("""
                Tính năng này sẽ cho phép bạn:
                - Upload nhiều ảnh/video cùng lúc
                - Tự động xử lý tuần tự
                - Download tất cả kết quả dưới dạng ZIP
                """)
|
|
|
|
|
# Script entry point: start the Gradio server with the app's CSS and theme.
if __name__ == "__main__":
    launch_options = {
        "css": custom_css,
        "theme": app_theme,
        "ssr_mode": False,
        "mcp_server": True,
        "show_error": True,
    }
    demo.launch(**launch_options)